[llvm] [NVPTX] Lower LLVM masked vector stores to PTX using new sink symbol syntax (PR #159387)

Drew Kersnar via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 17 14:35:30 PDT 2025


https://github.com/dakersnar updated https://github.com/llvm/llvm-project/pull/159387

>From 196ca9b3c763acbcaa5c7cfb2455a9dad6f18a4b Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:30:38 +0000
Subject: [PATCH 1/3] [NVPTX] Lower LLVM masked vector stores to PTX using the
 new sink symbol syntax

---
 .../llvm/Analysis/TargetTransformInfo.h       |   8 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   2 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../AArch64/AArch64TargetTransformInfo.h      |   2 +-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |   2 +-
 .../Hexagon/HexagonTargetTransformInfo.cpp    |   2 +-
 .../Hexagon/HexagonTargetTransformInfo.h      |   2 +-
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   |  10 +
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |   2 +
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  89 ++++-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  10 +-
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp |  26 ++
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |   3 +
 .../Target/RISCV/RISCVTargetTransformInfo.h   |   2 +-
 llvm/lib/Target/VE/VETargetTransformInfo.h    |   2 +-
 .../lib/Target/X86/X86TargetTransformInfo.cpp |   2 +-
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |   2 +-
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |   3 +-
 .../NVPTX/masked-store-variable-mask.ll       |  56 +++
 .../CodeGen/NVPTX/masked-store-vectors-256.ll | 319 ++++++++++++++++++
 20 files changed, 530 insertions(+), 18 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 41ff54f0781a2..e7886537379bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,9 +810,13 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
-  /// Return true if the target supports masked store.
+  /// Return true if the target supports masked store. Pass IsMaskConstant as
+  /// true only when the mask is known to be constant; false means the mask may
+  /// be either variable or constant. Targets that support masked stores only
+  /// with a constant mask use this to reject the general case.
   LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace) const;
+                                   unsigned AddressSpace,
+                                   bool IsMaskConstant = false) const;
   /// Return true if the target supports masked load.
   LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                                   unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 566e1cf51631a..33705e1dd5f98 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const {
+                                  unsigned AddressSpace, bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 09b50c5270e57..838712e55d0dd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+                                             unsigned AddressSpace, bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                          unsigned AddressSpace) const override;
 
   bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace) const override {
+                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/) const {
+                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..dc6b631d33451 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
   }
 }
 
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O,
+                                                 const char *Modifier) {
+  const MCOperand &Op = MI->getOperand(OpNum);
+  if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+    O << "_"; // NoRegister marks a masked-off lane: emit the sink symbol.
+  else
+    printOperand(MI, OpNum, O); // Anything else prints as a normal operand.
+}
+
 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
                                       raw_ostream &O) {
   int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     StringRef Modifier = {});
   void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                        StringRef Modifier = {});
+  void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+                                 const char *Modifier = nullptr);
   void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..6810b6008d8cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
   return Or;
 }
 
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+  SDNode *N = Op.getNode();
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Val = N->getOperand(1);
+  SDValue BasePtr = N->getOperand(2);
+  SDValue Offset = N->getOperand(3);
+  SDValue Mask = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ValVT = Val.getValueType();
+  MemSDNode *MemSD = cast<MemSDNode>(N);
+  assert(ValVT.isVector() && "Masked vector store must have vector type");
+  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+         "Unexpected alignment for masked store");
+
+  unsigned Opcode = 0;
+  switch (ValVT.getSimpleVT().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected masked vector store type");
+  case MVT::v4i64:
+  case MVT::v4f64: {
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  }
+  case MVT::v8i32:
+  case MVT::v8f32: {
+    Opcode = NVPTXISD::StoreV8;
+    break;
+  }
+  }
+
+  SmallVector<SDValue, 8> Ops;
+
+  // Operands of the replacement StoreV4/StoreV8 node. The chain comes first.
+  Ops.push_back(Chain);
+
+  // Next come the per-element values to store. The constant mask is folded
+  // into them: a masked-off element is replaced by the NoRegister sentinel.
+  assert(Mask.getValueType().isVector() &&
+         Mask.getValueType().getVectorElementType() == MVT::i1 &&
+         "Mask must be a vector of i1");
+  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+         "Mask expected to be a BUILD_VECTOR");
+  assert(Mask.getValueType().getVectorNumElements() ==
+             ValVT.getVectorNumElements() &&
+         "Mask size must be the same as the vector size");
+  for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    if (Mask->getConstantOperandVal(I) == 0) {
+      // Masked-off element: push the NoRegister sentinel in place of a value.
+      // The RegOrSink operand's print method emits it as the sink symbol "_".
+      Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+                                    ValVT.getVectorElementType()));
+    } else {
+      // Active element: extract lane I from the vector being stored.
+      SDValue ExtVal =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+                      Val, DAG.getIntPtrConstant(I, DL));
+      Ops.push_back(ExtVal);
+    }
+  }
+
+  // Then the base pointer operand.
+  Ops.push_back(BasePtr);
+
+  // Finally, the offset operand. We expect this to always be undef, and it will
+  // be ignored in lowering, but to mirror the handling of the other vector
+  // store instructions we include it in the new SDNode.
+  assert(Offset.getOpcode() == ISD::UNDEF &&
+         "Offset operand expected to be undef");
+  Ops.push_back(Offset);
+
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  return NewSt;
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
   case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..a8d6ff60c9b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1500,6 +1500,10 @@ def ADDR : Operand<pAny> {
   let MIOperandInfo = (ops ADDR_base, i32imm);
 }
 
+def RegOrSink : Operand<Any> {
+  let PrintMethod = "printRegisterOrSinkSymbol"; // "_" for NoRegister operands
+}
+
 def AtomicCode : Operand<i32> {
   let PrintMethod = "printAtomicCode";
 }
@@ -1806,7 +1810,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
     "\t[$addr], {{$src1, $src2}};">;
   def _v4 : NVPTXInst<
     (outs),
-    (ins O:$src1, O:$src2, O:$src3, O:$src4,
+    (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
          AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
          ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1814,8 +1818,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs),
-      (ins O:$src1, O:$src2, O:$src3, O:$src4,
-           O:$src5, O:$src6, O:$src7, O:$src8,
+      (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+           RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
            AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
            ADDR:$addr),
       "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f4f89613b358d..88b13cb38d67b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -597,6 +597,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+                                      unsigned AddrSpace, bool IsMaskConstant) const {
+
+  if (!IsMaskConstant)
+    return false; // Only masks known to be constant are supported.
+  
+  // Only 256-bit vectors are supported for now, so the store must be aligned
+  // to at least 32 (bytes, i.e. 256 bits).
+  if (Alignment < 32)
+    return false;
+
+  if (!ST->has256BitVectorLoadStore(AddrSpace))
+    return false;
+
+  auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+  if (!VTy)
+    return false;
+
+  auto *ScalarTy = VTy->getScalarType(); // Legal shapes: 8 x 32b or 4 x 64b.
+  if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+      (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+    return true;
+
+  return false;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b32d931bd3074..9e5500966fe10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddrSpace, bool IsMaskConstant) const override;
+
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 47e0a250d285a..80f10eb29bca4 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace) const {
+                                    unsigned AddressSpace, bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 42d6680c3cb7d..412c1b04cdf3a 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1137,7 +1137,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(3))))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<9>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT:    ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT:    and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT:    ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT:    and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT:    setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT:    ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT:    and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT:    setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT:    ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT:    ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT:    not.pred %p5, %p1;
+; CHECK-NEXT:    @%p5 bra $L__BB0_2;
+; CHECK-NEXT:  // %bb.1: // %cond.store
+; CHECK-NEXT:    st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT:  $L__BB0_2: // %else
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT:    not.pred %p6, %p2;
+; CHECK-NEXT:    @%p6 bra $L__BB0_4;
+; CHECK-NEXT:  // %bb.3: // %cond.store1
+; CHECK-NEXT:    st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT:  $L__BB0_4: // %else2
+; CHECK-NEXT:    setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT:    not.pred %p7, %p3;
+; CHECK-NEXT:    @%p7 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.5: // %cond.store3
+; CHECK-NEXT:    st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT:  $L__BB0_6: // %else4
+; CHECK-NEXT:    not.pred %p8, %p4;
+; CHECK-NEXT:    @%p8 bra $L__BB0_8;
+; CHECK-NEXT:  // %bb.7: // %cond.store5
+; CHECK-NEXT:    st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT:  $L__BB0_8: // %else6
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..0935bf80b04be
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,319 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll,
+; and contains testing for lowering 256-bit masked vector stores
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - Generic stores are legalized before the backend via scalarization; this
+;   file tests that path even though the LSV won't generate such stores.
+
+; 256-bit vector loads/stores are only legal on Blackwell+, so on sm_90 the vectors will be split.
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr %a
+  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr %a
+  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x float>, ptr %a
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x double>, ptr %a
+  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x float>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x double>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; Edge cases: constant masks that are all-true or all-false.
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT:    st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT:    st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK:       {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; This is an example of the pattern the LoadStoreVectorizer (LSV) outputs for these masked stores.
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT:    st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT:    st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT:    ret;
+  %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+  %load05 = extractelement <8 x i32> %1, i32 0
+  %load16 = extractelement <8 x i32> %1, i32 1
+  %load38 = extractelement <8 x i32> %1, i32 3
+  %load49 = extractelement <8 x i32> %1, i32 4
+  %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+  %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+  %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+  %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+  %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+  %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+  %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+  %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)

>From 861d8126795c859b71c2e27ee8c28f9c2c450503 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:52:13 +0000
Subject: [PATCH 2/3] Clang format

---
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h   | 3 ++-
 llvm/lib/Analysis/TargetTransformInfo.cpp              | 6 ++++--
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h   | 3 ++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h           | 4 ++--
 llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp | 3 ++-
 llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h   | 3 ++-
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp     | 5 +++--
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h       | 4 ++--
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h       | 3 ++-
 llvm/lib/Target/VE/VETargetTransformInfo.h             | 3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.cpp         | 3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.h           | 3 ++-
 12 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 33705e1dd5f98..267ada0e3c76a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace, bool IsMaskConstant) const {
+                                  unsigned AddressSpace,
+                                  bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 838712e55d0dd..b37f7969fc792 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,10 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace, bool IsMaskConstant) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
+                                             unsigned AddressSpace,
+                                             bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+                                     IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e40631d88748c..669d9e2ae1dad 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ee4f72552d90d..3c5b2195d8dcc 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -189,8 +189,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
   bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
                          unsigned AddressSpace) const override;
 
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
+  bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index c989bf77a9d51..74df572ef1521 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,8 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
+                                        unsigned /*AddressSpace*/,
+                                        bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e2674bb9cdad7..195cc616347b0 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,8 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace, bool IsMaskConstant) const override;
+                          unsigned AddressSpace,
+                          bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 88b13cb38d67b..045e9d2b099f8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -598,11 +598,12 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
 }
 
 bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                      unsigned AddrSpace, bool IsMaskConstant) const {
+                                      unsigned AddrSpace,
+                                      bool IsMaskConstant) const {
 
   if (!IsMaskConstant)
     return false;
-  
+
   //  We currently only support this feature for 256-bit vectors, so the
   //  alignment must be at least 32
   if (Alignment < 32)
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9e5500966fe10..9b353be7768c6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,8 +181,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddrSpace, bool IsMaskConstant) const override;
+  bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+                          bool IsMaskConstant) const override;
 
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 80f10eb29bca4..8838f881ca3db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 4971d9148b747..3f33760a5ead9 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,8 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index b16a2a593df03..851009f89e196 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace, bool IsMaskConstant) const {
+                                    unsigned AddressSpace,
+                                    bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 7f6ff65d427ed..c7c3d93600878 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,8 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
+                          unsigned AddressSpace,
+                          bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,

>From 34b9cd8c7ef9e51521fc3e5b408dea4977c1faab Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Fri, 17 Oct 2025 21:35:04 +0000
Subject: [PATCH 3/3] Update TTI to add enum value representing MaskKind to
 isLegalMaskedLoad/Store

---
 .../llvm/Analysis/TargetTransformInfo.h       | 22 +++++++++++--------
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  5 +++--
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 10 +++++----
 .../AArch64/AArch64TargetTransformInfo.h      |  5 +++--
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |  3 ++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |  8 +++----
 .../Hexagon/HexagonTargetTransformInfo.cpp    |  5 +++--
 .../Hexagon/HexagonTargetTransformInfo.h      |  6 ++---
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 16 +++++++++++---
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |  5 ++++-
 .../Target/RISCV/RISCVTargetTransformInfo.h   |  5 +++--
 llvm/lib/Target/VE/VETargetTransformInfo.h    | 11 +++++-----
 .../lib/Target/X86/X86TargetTransformInfo.cpp |  5 +++--
 llvm/lib/Target/X86/X86TargetTransformInfo.h  | 13 ++++++-----
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |  9 ++++++--
 15 files changed, 81 insertions(+), 47 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index e7886537379bc..ea0dae5cf39ba 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,16 +810,20 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
-  /// Return true if the target supports masked store. A value of false for
-  /// IsMaskConstant indicates that the mask could either be variable or
-  /// constant. This is for targets that only support masked store with a
-  /// constant mask.
-  LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace,
-                                   bool IsMaskConstant = false) const;
+  /// Some targets only support masked load/store with a constant mask.
+  enum MaskKind {
+    VariableOrConstantMask,
+    ConstantMask,
+  };
+
+  /// Return true if the target supports masked store.
+  LLVM_ABI bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     MaskKind MaskKind = VariableOrConstantMask) const;
   /// Return true if the target supports masked load.
-  LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const;
+  LLVM_ABI bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    MaskKind MaskKind = VariableOrConstantMask) const;
 
   /// Return true if the target supports nontemporal store.
   LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 267ada0e3c76a..da6e47c34da33 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -310,12 +310,13 @@ class TargetTransformInfoImplBase {
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
                                   unsigned AddressSpace,
-                                  bool IsMaskConstant) const {
+                                  TTI::MaskKind MaskKind) const {
     return false;
   }
 
   virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                 unsigned AddressSpace) const {
+                                 unsigned AddressSpace,
+                                 TTI::MaskKind MaskKind) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b37f7969fc792..c22f8c9e0145c 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,14 +468,16 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
                                              unsigned AddressSpace,
-                                             bool IsMaskConstant) const {
+                                             TTI::MaskKind MaskKind) const {
   return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
-                                     IsMaskConstant);
+                                     MaskKind);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                            unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
+                                            unsigned AddressSpace,
+                                            TTI::MaskKind MaskKind) const {
+  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
+                                    MaskKind);
 }
 
 bool TargetTransformInfo::isLegalNTStore(Type *DataType,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 669d9e2ae1dad..4fda311738a90 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -316,13 +316,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
                           unsigned /*AddressSpace*/,
-                          bool /*IsMaskConstant*/) const override {
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 9b250e6cac3ab..7ee7e9a953f5c 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
 }
 
 bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned /*AddressSpace*/) const {
+                                   unsigned /*AddressSpace*/,
+                                   TTI::MaskKind /*MaskKind*/) const {
   if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
     return false;
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 3c5b2195d8dcc..401f5ea044dab 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,12 +186,12 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool isProfitableLSRChainElement(Instruction *I) const override;
 
-  bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                         unsigned AddressSpace) const override;
+  bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                         TTI::MaskKind MaskKind) const override;
 
   bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
-                          bool /*IsMaskConstant*/) const override {
-    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
+                          TTI::MaskKind MaskKind) const override {
+    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
   }
 
   bool forceScalarizeMaskedGather(VectorType *VTy,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 74df572ef1521..da006ff037eeb 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -342,14 +342,15 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
                                         unsigned /*AddressSpace*/,
-                                        bool /*IsMaskConstant*/) const {
+                                        TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
 }
 
 bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
-                                       unsigned /*AddressSpace*/) const {
+                                       unsigned /*AddressSpace*/,
+                                       TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 195cc616347b0..1ca7408d75126 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -167,9 +167,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
                           unsigned AddressSpace,
-                          bool IsMaskConstant) const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
+                          TTI::MaskKind MaskKind) const override;
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                         TTI::MaskKind MaskKind) const override;
 
   /// @}
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 045e9d2b099f8..bc186d86fb0a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -599,9 +599,8 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
 
 bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
                                       unsigned AddrSpace,
-                                      bool IsMaskConstant) const {
-
-  if (!IsMaskConstant)
+                                      TTI::MaskKind MaskKind) const {
+  if (MaskKind != TTI::MaskKind::ConstantMask)
     return false;
 
   //  We currently only support this feature for 256-bit vectors, so the
@@ -624,6 +623,17 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
   return false;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+                                     unsigned /*AddrSpace*/,
+                                     TTI::MaskKind MaskKind) const {
+  if (MaskKind != TTI::MaskKind::ConstantMask)
+    return false;
+
+  if (Alignment < DL.getTypeStoreSize(DataTy))
+    return false;
+  return true;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9b353be7768c6..b690038f85ee2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -182,7 +182,10 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
                                   Intrinsic::ID IID) const override;
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
-                          bool IsMaskConstant) const override;
+                          TTI::MaskKind MaskKind) const override;
+
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace,
+                         TTI::MaskKind MaskKind) const override;
 
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 8838f881ca3db..4887ab1fb26bc 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -283,12 +283,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
                           unsigned /*AddressSpace*/,
-                          bool /*IsMaskConstant*/) const override {
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 3f33760a5ead9..eed3832c9f1fb 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -134,13 +134,14 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
   }
 
   // Load & Store {
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+  bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+                    TargetTransformInfo::MaskKind /*MaskKind*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/,
-                          bool /*IsMaskConstant*/) const override {
+  bool isLegalMaskedStore(
+      Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+      TargetTransformInfo::MaskKind /*MaskKind*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 851009f89e196..60f148f9f747e 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6317,7 +6317,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
 }
 
 bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned AddressSpace) const {
+                                   unsigned AddressSpace,
+                                   TTI::MaskKind MaskKind) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
@@ -6331,7 +6332,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
                                     unsigned AddressSpace,
-                                    bool IsMaskConstant) const {
+                                    TTI::MaskKind MaskKind) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index c7c3d93600878..02b642f599270 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -268,11 +268,14 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2) const override;
   bool canMacroFuseCmp() const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace,
-                          bool IsMaskConstant = false) const override;
+  bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    TTI::MaskKind MaskKind =
+                        TTI::MaskKind::VariableOrConstantMask) const override;
+  bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     TTI::MaskKind MaskKind =
+                         TTI::MaskKind::VariableOrConstantMask) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 412c1b04cdf3a..d5b4e03513fc8 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1128,7 +1128,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(),
               cast<PointerType>(CI->getArgOperand(0)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(2))
+                  ? TTI::MaskKind::ConstantMask
+                  : TTI::MaskKind::VariableOrConstantMask))
         return false;
       scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
@@ -1138,7 +1141,9 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
                   ->getAddressSpace(),
-              isConstantIntVector(CI->getArgOperand(3))))
+              isConstantIntVector(CI->getArgOperand(3))
+                  ? TTI::MaskKind::ConstantMask
+                  : TTI::MaskKind::VariableOrConstantMask))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;



More information about the llvm-commits mailing list