[llvm] [NVPTX] Lower LLVM masked vector stores to PTX using new sink symbol syntax (PR #159387)

Drew Kersnar via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 17 08:52:38 PDT 2025


https://github.com/dakersnar updated https://github.com/llvm/llvm-project/pull/159387

From 196ca9b3c763acbcaa5c7cfb2455a9dad6f18a4b Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:30:38 +0000
Subject: [PATCH 1/2] [NVPTX] Lower LLVM masked vector stores to PTX using the
 new sink symbol syntax

---
 .../llvm/Analysis/TargetTransformInfo.h       |   8 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   2 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../AArch64/AArch64TargetTransformInfo.h      |   2 +-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |   2 +-
 .../Hexagon/HexagonTargetTransformInfo.cpp    |   2 +-
 .../Hexagon/HexagonTargetTransformInfo.h      |   2 +-
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   |  10 +
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |   2 +
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  89 ++++-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  10 +-
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp |  26 ++
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |   3 +
 .../Target/RISCV/RISCVTargetTransformInfo.h   |   2 +-
 llvm/lib/Target/VE/VETargetTransformInfo.h    |   2 +-
 .../lib/Target/X86/X86TargetTransformInfo.cpp |   2 +-
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |   2 +-
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |   3 +-
 .../NVPTX/masked-store-variable-mask.ll       |  56 +++
 .../CodeGen/NVPTX/masked-store-vectors-256.ll | 319 ++++++++++++++++++
 20 files changed, 530 insertions(+), 18 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 41ff54f0781a2..e7886537379bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,9 +810,13 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
-  /// Return true if the target supports masked store.
+  /// Return true if the target supports masked store. A false value of
+  /// IsMaskConstant indicates that the mask may be either variable or
+  /// constant; this lets targets that only support masked stores with a
+  /// constant mask reject the general case.
   LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace) const;
+                                   unsigned AddressSpace,
+                                   bool IsMaskConstant = false) const;
   /// Return true if the target supports masked load.
   LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                                   unsigned AddressSpace) const;
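
As a minimal sketch (not part of the patch), this is how a pre-ISel pass might query the extended hook for a call to llvm.masked.store; all names here are illustrative, and the in-tree caller is ScalarizeMaskedMemIntrin, updated further down in this patch:

    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Decide whether a masked store intrinsic can be kept intact for the
    // target, or must be scalarized.
    static bool keepMaskedStoreIntact(const TargetTransformInfo &TTI,
                                      CallInst *CI) {
      Value *Data = CI->getArgOperand(0); // value being stored
      Value *Ptr = CI->getArgOperand(1);  // destination pointer
      Align Alignment =
          cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue();
      Value *Mask = CI->getArgOperand(3); // <N x i1> mask
      // isa<Constant> is a simplified stand-in for the pass's stricter
      // constant-int-vector check.
      bool IsMaskConstant = isa<Constant>(Mask);
      return TTI.isLegalMaskedStore(
          Data->getType(), Alignment,
          cast<PointerType>(Ptr->getType())->getAddressSpace(),
          IsMaskConstant);
    }
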
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 566e1cf51631a..33705e1dd5f98 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const {
+                                  unsigned AddressSpace, bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 09b50c5270e57..838712e55d0dd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+                                             unsigned AddressSpace, bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                          unsigned AddressSpace) const override;
 
   bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace) const override {
+                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/) const {
+                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..dc6b631d33451 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
   }
 }
 
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O,
+                                                 const char *Modifier) {
+  const MCOperand &Op = MI->getOperand(OpNum);
+  if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+    O << "_";
+  else
+    printOperand(MI, OpNum, O);
+}
+
 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
                                       raw_ostream &O) {
   int64_t Imm = MI->getOperand(OpNum).getImm();
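
A quick illustration of the new print method (a sketch in comment form, with register names borrowed from the tests at the end of this patch):

    // For a v4 store whose lanes 1 and 3 were masked off, the MI sources are
    // {%rd2, NoRegister, %rd4, NoRegister}. printRegisterOrSinkSymbol prints
    // "_" for each NoRegister operand, so the ST_VEC v4 pattern in
    // NVPTXInstrInfo.td ends up emitting:
    //   st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
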
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     StringRef Modifier = {});
   void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                        StringRef Modifier = {});
+  void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+                                 const char *Modifier = nullptr);
   void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..6810b6008d8cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
   return Or;
 }
 
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+  SDNode *N = Op.getNode();
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Val = N->getOperand(1);
+  SDValue BasePtr = N->getOperand(2);
+  SDValue Offset = N->getOperand(3);
+  SDValue Mask = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ValVT = Val.getValueType();
+  MemSDNode *MemSD = cast<MemSDNode>(N);
+  assert(ValVT.isVector() && "Masked vector store must have vector type");
+  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+         "Unexpected alignment for masked store");
+
+  unsigned Opcode = 0;
+  switch (ValVT.getSimpleVT().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected masked vector store type");
+  case MVT::v4i64:
+  case MVT::v4f64: {
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  }
+  case MVT::v8i32:
+  case MVT::v8f32: {
+    Opcode = NVPTXISD::StoreV8;
+    break;
+  }
+  }
+
+  SmallVector<SDValue, 8> Ops;
+
+  // Construct the new SDNode. First operand is the chain.
+  Ops.push_back(Chain);
+
+  // The next N operands are the values to store. Encode the mask into the
+  // values using the sentinel register 0 to represent a masked-off element.
+  assert(Mask.getValueType().isVector() &&
+         Mask.getValueType().getVectorElementType() == MVT::i1 &&
+         "Mask must be a vector of i1");
+  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+         "Mask expected to be a BUILD_VECTOR");
+  assert(Mask.getValueType().getVectorNumElements() ==
+             ValVT.getVectorNumElements() &&
+         "Mask size must be the same as the vector size");
+  for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    if (Mask->getConstantOperandVal(I) == 0) {
+      // Append the sentinel register 0 (NoRegister) to the Ops vector to
+      // represent a masked-off element; this is handled later in TableGen.
+      Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+                                    ValVT.getVectorElementType()));
+    } else {
+      // Extract the element from the vector to store
+      SDValue ExtVal =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+                      Val, DAG.getIntPtrConstant(I, DL));
+      Ops.push_back(ExtVal);
+    }
+  }
+
+  // Next, the pointer operand.
+  Ops.push_back(BasePtr);
+
+  // Finally, the offset operand. We expect this to always be undef, and it will
+  // be ignored in lowering, but to mirror the handling of the other vector
+  // store instructions we include it in the new SDNode.
+  assert(Offset.getOpcode() == ISD::UNDEF &&
+         "Offset operand expected to be undef");
+  Ops.push_back(Offset);
+
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  return NewSt;
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
   case ISD::SHL_PARTS:
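
To make the operand encoding concrete, here is a worked sketch (comment form only, not part of the patch) of what lowerMSTORE builds for a <4 x i64> masked store whose constant mask is <1, 0, 1, 0>:

    // NVPTXISD::StoreV4 node built by lowerMSTORE for mask <i1 1, i1 0, i1 1, i1 0>:
    //   Ops = { Chain,
    //           extractelt(Val, 0),        // lane 0: mask bit set
    //           Register(NoRegister, i64), // lane 1: masked off -> "_" in PTX
    //           extractelt(Val, 2),        // lane 2: mask bit set
    //           Register(NoRegister, i64), // lane 3: masked off
    //           BasePtr, Offset /*undef*/ };
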
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..a8d6ff60c9b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1500,6 +1500,10 @@ def ADDR : Operand<pAny> {
   let MIOperandInfo = (ops ADDR_base, i32imm);
 }
 
+def RegOrSink : Operand<Any> {
+  let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
 def AtomicCode : Operand<i32> {
   let PrintMethod = "printAtomicCode";
 }
@@ -1806,7 +1810,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
     "\t[$addr], {{$src1, $src2}};">;
   def _v4 : NVPTXInst<
     (outs),
-    (ins O:$src1, O:$src2, O:$src3, O:$src4,
+    (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
          AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
          ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1814,8 +1818,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs),
-      (ins O:$src1, O:$src2, O:$src3, O:$src4,
-           O:$src5, O:$src6, O:$src7, O:$src8,
+      (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+           RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
            AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
            ADDR:$addr),
       "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f4f89613b358d..88b13cb38d67b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -597,6 +597,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+                                      unsigned AddrSpace, bool IsMaskConstant) const {
+
+  if (!IsMaskConstant)
+    return false;
+  
+  //  We currently only support this feature for 256-bit vectors, so the
+  //  alignment must be at least 32
+  if (Alignment < 32)
+    return false;
+
+  if (!ST->has256BitVectorLoadStore(AddrSpace))
+    return false;
+
+  auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+  if (!VTy)
+    return false;
+
+  auto *ScalarTy = VTy->getScalarType();
+  if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+      (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+    return true;
+
+  return false;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
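
For reference, the checks above accept exactly the two 256-bit shapes exercised by the new tests; a hedged summary, assuming ST->has256BitVectorLoadStore(AS) holds:

    //   <8 x i32> / <8 x float>,  constant mask, Align >= 32 -> true  (8 x 32 bits)
    //   <4 x i64> / <4 x double>, constant mask, Align >= 32 -> true  (4 x 64 bits)
    //   same types with Align < 32                           -> false (under-aligned)
    //   <4 x i32>, <16 x i16>, scalable vectors, etc.        -> false (not a supported 256-bit fixed vector)
    //   any variable (non-constant) mask                     -> false (scalarized instead)
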
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b32d931bd3074..9e5500966fe10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddrSpace, bool IsMaskConstant) const override;
+
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 47e0a250d285a..80f10eb29bca4 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace) const {
+                                    unsigned AddressSpace, bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 42d6680c3cb7d..412c1b04cdf3a 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1137,7 +1137,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(3))))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<9>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT:    ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT:    and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT:    ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT:    and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT:    setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT:    ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT:    and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT:    setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT:    ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT:    ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT:    not.pred %p5, %p1;
+; CHECK-NEXT:    @%p5 bra $L__BB0_2;
+; CHECK-NEXT:  // %bb.1: // %cond.store
+; CHECK-NEXT:    st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT:  $L__BB0_2: // %else
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT:    not.pred %p6, %p2;
+; CHECK-NEXT:    @%p6 bra $L__BB0_4;
+; CHECK-NEXT:  // %bb.3: // %cond.store1
+; CHECK-NEXT:    st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT:  $L__BB0_4: // %else2
+; CHECK-NEXT:    setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT:    not.pred %p7, %p3;
+; CHECK-NEXT:    @%p7 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.5: // %cond.store3
+; CHECK-NEXT:    st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT:  $L__BB0_6: // %else4
+; CHECK-NEXT:    not.pred %p8, %p4;
+; CHECK-NEXT:    @%p8 bra $L__BB0_8;
+; CHECK-NEXT:  // %bb.7: // %cond.store5
+; CHECK-NEXT:    st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT:  $L__BB0_8: // %else6
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..0935bf80b04be
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,319 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll and covers the lowering of
+; 256-bit masked vector stores.
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores are legalized via scalarization before reaching the
+;   backend; this file tests that path even though the LSV will not generate them.
+
+; 256-bit vector loads/stores are only legal on Blackwell (sm_100) and later, so on sm_90 the vectors will be split.
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr %a
+  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr %a
+  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x float>, ptr %a
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x double>, ptr %a
+  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x float>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x double>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; edge cases
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT:    st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT:    st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK:       {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; This is an example of the pattern the LSV produces for these masked stores
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT:    st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT:    st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT:    ret;
+  %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+  %load05 = extractelement <8 x i32> %1, i32 0
+  %load16 = extractelement <8 x i32> %1, i32 1
+  %load38 = extractelement <8 x i32> %1, i32 3
+  %load49 = extractelement <8 x i32> %1, i32 4
+  %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+  %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+  %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+  %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+  %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+  %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+  %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+  %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)

From 861d8126795c859b71c2e27ee8c28f9c2c450503 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:52:13 +0000
Subject: [PATCH 2/2] Clang format

---
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h   | 3 ++-
 llvm/lib/Analysis/TargetTransformInfo.cpp              | 6 ++++--
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h   | 3 ++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h           | 4 ++--
 llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp | 3 ++-
 llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h   | 3 ++-
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp     | 5 +++--
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h       | 4 ++--
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h       | 3 ++-
 llvm/lib/Target/VE/VETargetTransformInfo.h             | 3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.cpp         | 3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.h           | 3 ++-
 12 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 33705e1dd5f98..267ada0e3c76a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace, bool IsMaskConstant) const {
+                                  unsigned AddressSpace,
+                                  bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 838712e55d0dd..b37f7969fc792 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,10 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace, bool IsMaskConstant) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
+                                             unsigned AddressSpace,
+                                             bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+                                     IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e40631d88748c..669d9e2ae1dad 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ee4f72552d90d..3c5b2195d8dcc 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -189,8 +189,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
   bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
                          unsigned AddressSpace) const override;
 
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
+  bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index c989bf77a9d51..74df572ef1521 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,8 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
+                                        unsigned /*AddressSpace*/,
+                                        bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e2674bb9cdad7..195cc616347b0 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,8 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace, bool IsMaskConstant) const override;
+                          unsigned AddressSpace,
+                          bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 88b13cb38d67b..045e9d2b099f8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -598,11 +598,12 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
 }
 
 bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                      unsigned AddrSpace, bool IsMaskConstant) const {
+                                      unsigned AddrSpace,
+                                      bool IsMaskConstant) const {
 
   if (!IsMaskConstant)
     return false;
-  
+
   //  We currently only support this feature for 256-bit vectors, so the
   //  alignment must be at least 32
   if (Alignment < 32)
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9e5500966fe10..9b353be7768c6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,8 +181,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddrSpace, bool IsMaskConstant) const override;
+  bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+                          bool IsMaskConstant) const override;
 
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 80f10eb29bca4..8838f881ca3db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 4971d9148b747..3f33760a5ead9 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,8 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index b16a2a593df03..851009f89e196 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace, bool IsMaskConstant) const {
+                                    unsigned AddressSpace,
+                                    bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 7f6ff65d427ed..c7c3d93600878 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,8 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
+                          unsigned AddressSpace,
+                          bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,


