[llvm] [NVPTX] Lower LLVM masked vector stores to PTX using new sink symbol syntax (PR #159387)
Drew Kersnar via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 17 08:52:38 PDT 2025
https://github.com/dakersnar updated https://github.com/llvm/llvm-project/pull/159387
From 196ca9b3c763acbcaa5c7cfb2455a9dad6f18a4b Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:30:38 +0000
Subject: [PATCH 1/2] [NVPTX] Lower LLVM masked vector stores to PTX using the
new sink symbol syntax
---
.../llvm/Analysis/TargetTransformInfo.h | 8 +-
.../llvm/Analysis/TargetTransformInfoImpl.h | 2 +-
llvm/lib/Analysis/TargetTransformInfo.cpp | 4 +-
.../AArch64/AArch64TargetTransformInfo.h | 2 +-
llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 2 +-
.../Hexagon/HexagonTargetTransformInfo.cpp | 2 +-
.../Hexagon/HexagonTargetTransformInfo.h | 2 +-
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 10 +
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.h | 2 +
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 89 ++++-
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 10 +-
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 26 ++
.../Target/NVPTX/NVPTXTargetTransformInfo.h | 3 +
.../Target/RISCV/RISCVTargetTransformInfo.h | 2 +-
llvm/lib/Target/VE/VETargetTransformInfo.h | 2 +-
.../lib/Target/X86/X86TargetTransformInfo.cpp | 2 +-
llvm/lib/Target/X86/X86TargetTransformInfo.h | 2 +-
.../Scalar/ScalarizeMaskedMemIntrin.cpp | 3 +-
.../NVPTX/masked-store-variable-mask.ll | 56 +++
.../CodeGen/NVPTX/masked-store-vectors-256.ll | 319 ++++++++++++++++++
20 files changed, 530 insertions(+), 18 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
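Before the diffs, a quick illustration of the pattern this patch enables: a masked store with a compile-time-constant mask, a 256-bit vector type, and a global-address-space pointer is now selected to a single vector store whose inactive lanes use the PTX sink symbol '_'. A minimal standalone sketch (mirroring the global_8xi32 test added below; the function name is illustrative, and the expected output assumes llc -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88):
; Masked-off lanes are emitted as the '_' sink symbol in the PTX store:
;   st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
define void @sketch_masked_store_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
  %v = load <8 x i32>, ptr addrspace(1) %a, align 32
  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %v, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
  ret void
}
declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)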
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 41ff54f0781a2..e7886537379bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,9 +810,13 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
- /// Return true if the target supports masked store.
+ /// Return true if the target supports masked store. IsMaskConstant is true
+ /// when the mask is known to be a constant vector; false means the mask may
+ /// be either variable or constant. This distinction matters for targets that
+ /// only support masked stores with a constant mask.
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ unsigned AddressSpace,
+ bool IsMaskConstant = false) const;
/// Return true if the target supports masked load.
LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 566e1cf51631a..33705e1dd5f98 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 09b50c5270e57..838712e55d0dd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace, bool IsMaskConstant) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override {
+ unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
}
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..dc6b631d33451 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+ raw_ostream &O,
+ const char *Modifier) {
+ const MCOperand &Op = MI->getOperand(OpNum);
+ if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+ O << "_";
+ else
+ printOperand(MI, OpNum, O);
+}
+
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..6810b6008d8cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
return Or;
}
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Val = N->getOperand(1);
+ SDValue BasePtr = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Mask = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ValVT = Val.getValueType();
+ MemSDNode *MemSD = cast<MemSDNode>(N);
+ assert(ValVT.isVector() && "Masked vector store must have vector type");
+ assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+ "Unexpected alignment for masked store");
+
+ unsigned Opcode = 0;
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unexpected masked vector store type");
+ case MVT::v4i64:
+ case MVT::v4f64: {
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ }
+ case MVT::v8i32:
+ case MVT::v8f32: {
+ Opcode = NVPTXISD::StoreV8;
+ break;
+ }
+ }
+
+ SmallVector<SDValue, 8> Ops;
+
+ // Construct the new SDNode. First operand is the chain.
+ Ops.push_back(Chain);
+
+ // The next N operands are the values to store. Encode the mask into the
+ // values using the sentinel register 0 to represent a masked-off element.
+ assert(Mask.getValueType().isVector() &&
+ Mask.getValueType().getVectorElementType() == MVT::i1 &&
+ "Mask must be a vector of i1");
+ assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+ "Mask expected to be a BUILD_VECTOR");
+ assert(Mask.getValueType().getVectorNumElements() ==
+ ValVT.getVectorNumElements() &&
+ "Mask size must be the same as the vector size");
+ for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+ assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+ "Mask elements must be constants");
+ if (Mask->getConstantOperandVal(I) == 0) {
+ // Append a sentinel register 0 to the Ops vector to represent a masked-off
+ // element; this will be handled in TableGen.
+ Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+ ValVT.getVectorElementType()));
+ } else {
+ // Extract the element from the vector to store
+ SDValue ExtVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+ Val, DAG.getIntPtrConstant(I, DL));
+ Ops.push_back(ExtVal);
+ }
+ }
+
+ // Next, the pointer operand.
+ Ops.push_back(BasePtr);
+
+ // Finally, the offset operand. We expect this to always be undef, and it will
+ // be ignored in lowering, but to mirror the handling of the other vector
+ // store instructions we include it in the new SDNode.
+ assert(Offset.getOpcode() == ISD::UNDEF &&
+ "Offset operand expected to be undef");
+ Ops.push_back(Offset);
+
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+ return NewSt;
+}
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE: {
+ assert(STI.has256BitVectorLoadStore(
+ cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+ "Masked store vector not supported on subtarget.");
+ return lowerMSTORE(Op, DAG);
+ }
case ISD::LOAD:
return LowerLOAD(Op, DAG);
case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..a8d6ff60c9b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1500,6 +1500,10 @@ def ADDR : Operand<pAny> {
let MIOperandInfo = (ops ADDR_base, i32imm);
}
+def RegOrSink : Operand<Any> {
+ let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
def AtomicCode : Operand<i32> {
let PrintMethod = "printAtomicCode";
}
@@ -1806,7 +1810,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
"\t[$addr], {{$src1, $src2}};">;
def _v4 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1814,8 +1818,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
if support_v8 then
def _v8 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
- O:$src5, O:$src6, O:$src7, O:$src8,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+ RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f4f89613b358d..88b13cb38d67b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -597,6 +597,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const {
+
+ if (!IsMaskConstant)
+ return false;
+
+ // We currently only support this feature for 256-bit vectors, so the
+ // alignment must be at least 32
+ if (Alignment < 32)
+ return false;
+
+ if (!ST->has256BitVectorLoadStore(AddrSpace))
+ return false;
+
+ auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+ if (!VTy)
+ return false;
+
+ auto *ScalarTy = VTy->getScalarType();
+ if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+ (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+ return true;
+
+ return false;
+}
+
unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
// 256 bit loads/stores are currently only supported for global address space
if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b32d931bd3074..9e5500966fe10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const override;
+
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 47e0a250d285a..80f10eb29bca4 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant = false) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 42d6680c3cb7d..412c1b04cdf3a 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1137,7 +1137,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getArgOperand(0)->getType(),
cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
cast<PointerType>(CI->getArgOperand(1)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(3))))
return false;
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<9>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT: ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT: ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT: and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT: setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT: ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT: and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT: setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT: ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT: ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT: ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT: not.pred %p5, %p1;
+; CHECK-NEXT: @%p5 bra $L__BB0_2;
+; CHECK-NEXT: // %bb.1: // %cond.store
+; CHECK-NEXT: st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT: $L__BB0_2: // %else
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT: not.pred %p6, %p2;
+; CHECK-NEXT: @%p6 bra $L__BB0_4;
+; CHECK-NEXT: // %bb.3: // %cond.store1
+; CHECK-NEXT: st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT: $L__BB0_4: // %else2
+; CHECK-NEXT: setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT: not.pred %p7, %p3;
+; CHECK-NEXT: @%p7 bra $L__BB0_6;
+; CHECK-NEXT: // %bb.5: // %cond.store3
+; CHECK-NEXT: st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT: $L__BB0_6: // %else4
+; CHECK-NEXT: not.pred %p8, %p4;
+; CHECK-NEXT: @%p8 bra $L__BB0_8;
+; CHECK-NEXT: // %bb.7: // %cond.store5
+; CHECK-NEXT: st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT: $L__BB0_8: // %else6
+; CHECK-NEXT: ret;
+ %a.load = load <4 x i64>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+ ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..0935bf80b04be
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,319 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll and covers the lowering of
+; 256-bit masked vector stores.
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores will be scalarized before reaching the backend; this
+;   file tests that path even though the LSV will not generate them.
+
+; 256-bit vector loads/stores are only legal on Blackwell+, so on sm_90 the vectors will be split.
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+ %a.load = load <8 x i32>, ptr %a
+ tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+ %a.load = load <4 x i64>, ptr %a
+ tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+ %a.load = load <8 x float>, ptr %a
+ tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+ %a.load = load <4 x double>, ptr %a
+ tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+ %a.load = load <4 x i64>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x float>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+ %a.load = load <4 x double>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; edge cases
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT: st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT: st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+ ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK: {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+ ret void
+}
+
+; This is an example pattern for the LSV's output of these masked stores
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT: st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT: st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT: ret;
+ %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+ %load05 = extractelement <8 x i32> %1, i32 0
+ %load16 = extractelement <8 x i32> %1, i32 1
+ %load38 = extractelement <8 x i32> %1, i32 3
+ %load49 = extractelement <8 x i32> %1, i32 4
+ %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+ %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+ %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+ %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+ %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+ %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+ %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+ %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+ call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+ ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)
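For contrast with the tests above, here is a sketch of an input the new NVPTXTTIImpl::isLegalMaskedStore hook rejects (the mask is not a constant vector, so IsMaskConstant is false); as in masked-store-variable-mask.ll, ScalarizeMaskedMemIntrin expands it before instruction selection. The function name is illustrative:
; A variable mask cannot use the sink-symbol form; the intrinsic is instead
; expanded into per-lane conditional scalar stores before reaching the backend.
define void @sketch_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
  %v = load <4 x i64>, ptr addrspace(1) %a, align 32
  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %v, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
  ret void
}
declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)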
From 861d8126795c859b71c2e27ee8c28f9c2c450503 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:52:13 +0000
Subject: [PATCH 2/2] Clang format
---
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 3 ++-
llvm/lib/Analysis/TargetTransformInfo.cpp | 6 ++++--
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 3 ++-
llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 4 ++--
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp | 3 ++-
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h | 3 ++-
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp | 5 +++--
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h | 4 ++--
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 3 ++-
llvm/lib/Target/VE/VETargetTransformInfo.h | 3 ++-
llvm/lib/Target/X86/X86TargetTransformInfo.cpp | 3 ++-
llvm/lib/Target/X86/X86TargetTransformInfo.h | 3 ++-
12 files changed, 27 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 33705e1dd5f98..267ada0e3c76a 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,8 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const {
+ unsigned AddressSpace,
+ bool IsMaskConstant) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 838712e55d0dd..b37f7969fc792 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,10 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
+ unsigned AddressSpace,
+ bool IsMaskConstant) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+ IsMaskConstant);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e40631d88748c..669d9e2ae1dad 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ee4f72552d90d..3c5b2195d8dcc 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -189,8 +189,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
unsigned AddressSpace) const override;
- bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
+ bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
}
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index c989bf77a9d51..74df572ef1521 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,8 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e2674bb9cdad7..195cc616347b0 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,8 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const override;
+ unsigned AddressSpace,
+ bool IsMaskConstant) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 88b13cb38d67b..045e9d2b099f8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -598,11 +598,12 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
}
bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddrSpace, bool IsMaskConstant) const {
+ unsigned AddrSpace,
+ bool IsMaskConstant) const {
if (!IsMaskConstant)
return false;
-
+
// We currently only support this feature for 256-bit vectors, so the
// alignment must be at least 32
if (Alignment < 32)
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9e5500966fe10..9b353be7768c6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,8 +181,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
- bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddrSpace, bool IsMaskConstant) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+ bool IsMaskConstant) const override;
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 80f10eb29bca4..8838f881ca3db 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 4971d9148b747..3f33760a5ead9 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,8 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index b16a2a593df03..851009f89e196 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const {
+ unsigned AddressSpace,
+ bool IsMaskConstant) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 7f6ff65d427ed..c7c3d93600878 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,8 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant = false) const override;
+ unsigned AddressSpace,
+ bool IsMaskConstant = false) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,