[llvm] [NVPTX] Lower LLVM masked vector loads and stores to PTX (PR #159387)
Drew Kersnar via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 27 14:22:51 PDT 2025
https://github.com/dakersnar updated https://github.com/llvm/llvm-project/pull/159387
>From 069c336c4b510e3b59a98a4b3e3386cd244ea77a Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:30:38 +0000
Subject: [PATCH 01/11] [NVPTX] Lower LLVM masked vector stores to PTX using
the new sink symbol syntax
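
The sink symbol "_" lets a st.vN instruction skip individual lanes, so a masked
store whose mask is a compile-time constant can be lowered to a single vector
store. For example, taken from the new masked-store-vectors-256.ll test in this
patch (sm_100/+ptx88):

  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)

lowers to:

  st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};

Masked stores with variable masks are unaffected; they are still scalarized
before instruction selection (see the new masked-store-variable-mask.ll test).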
---
.../llvm/Analysis/TargetTransformInfo.h | 8 +-
.../llvm/Analysis/TargetTransformInfoImpl.h | 2 +-
llvm/lib/Analysis/TargetTransformInfo.cpp | 4 +-
.../AArch64/AArch64TargetTransformInfo.h | 2 +-
llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 2 +-
.../Hexagon/HexagonTargetTransformInfo.cpp | 2 +-
.../Hexagon/HexagonTargetTransformInfo.h | 2 +-
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 10 +
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.h | 2 +
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 89 ++++-
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 10 +-
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 26 ++
.../Target/NVPTX/NVPTXTargetTransformInfo.h | 3 +
.../Target/RISCV/RISCVTargetTransformInfo.h | 2 +-
llvm/lib/Target/VE/VETargetTransformInfo.h | 2 +-
.../lib/Target/X86/X86TargetTransformInfo.cpp | 2 +-
llvm/lib/Target/X86/X86TargetTransformInfo.h | 2 +-
.../Scalar/ScalarizeMaskedMemIntrin.cpp | 3 +-
.../NVPTX/masked-store-variable-mask.ll | 56 +++
.../CodeGen/NVPTX/masked-store-vectors-256.ll | 319 ++++++++++++++++++
20 files changed, 530 insertions(+), 18 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5d3b233ed6b6a..f0d30913e1a52 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -813,9 +813,13 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
- /// Return true if the target supports masked store.
+ /// Return true if the target supports masked store. IsMaskConstant is true
+ /// when the mask is known to be a constant vector; a value of false means the
+ /// mask may be either variable or constant. This lets targets that only
+ /// support masked stores with a constant mask report legality accurately.
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ unsigned AddressSpace,
+ bool IsMaskConstant = false) const;
/// Return true if the target supports masked load.
LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4cd607c0d0c8d..06df6ca58b122 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index bf62623099a97..9310350b48585 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,8 +468,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace, bool IsMaskConstant) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override {
+ unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
}
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 77913f27838e2..c8717876157c5 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -395,6 +395,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+ raw_ostream &O,
+ const char *Modifier) {
+ const MCOperand &Op = MI->getOperand(OpNum);
+ if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+ O << "_";
+ else
+ printOperand(MI, OpNum, O);
+}
+
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2f1a7ad2d401f..e52854c3b5627 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -769,7 +769,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -3181,6 +3181,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
return Or;
}
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Val = N->getOperand(1);
+ SDValue BasePtr = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Mask = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ValVT = Val.getValueType();
+ MemSDNode *MemSD = cast<MemSDNode>(N);
+ assert(ValVT.isVector() && "Masked vector store must have vector type");
+ assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+ "Unexpected alignment for masked store");
+
+ unsigned Opcode = 0;
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unexpected masked vector store type");
+ case MVT::v4i64:
+ case MVT::v4f64: {
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ }
+ case MVT::v8i32:
+ case MVT::v8f32: {
+ Opcode = NVPTXISD::StoreV8;
+ break;
+ }
+ }
+
+ SmallVector<SDValue, 8> Ops;
+
+ // Construct the new SDNode. First operand is the chain.
+ Ops.push_back(Chain);
+
+ // The next N operands are the values to store. Encode the mask into the
+ // values using the sentinel register 0 to represent a masked-off element.
+ assert(Mask.getValueType().isVector() &&
+ Mask.getValueType().getVectorElementType() == MVT::i1 &&
+ "Mask must be a vector of i1");
+ assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+ "Mask expected to be a BUILD_VECTOR");
+ assert(Mask.getValueType().getVectorNumElements() ==
+ ValVT.getVectorNumElements() &&
+ "Mask size must be the same as the vector size");
+ for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+ assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+ "Mask elements must be constants");
+ if (Mask->getConstantOperandVal(I) == 0) {
+ // Append a sentinel register 0 to the Ops vector to represent a
+ // masked-off element; this will be handled in tablegen.
+ Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+ ValVT.getVectorElementType()));
+ } else {
+ // Extract the element from the vector to store
+ SDValue ExtVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+ Val, DAG.getIntPtrConstant(I, DL));
+ Ops.push_back(ExtVal);
+ }
+ }
+
+ // Next, the pointer operand.
+ Ops.push_back(BasePtr);
+
+ // Finally, the offset operand. We expect this to always be undef, and it will
+ // be ignored in lowering, but to mirror the handling of the other vector
+ // store instructions we include it in the new SDNode.
+ assert(Offset.getOpcode() == ISD::UNDEF &&
+ "Offset operand expected to be undef");
+ Ops.push_back(Offset);
+
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+ return NewSt;
+}
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -3217,6 +3298,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE: {
+ assert(STI.has256BitVectorLoadStore(
+ cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+ "Masked store vector not supported on subtarget.");
+ return lowerMSTORE(Op, DAG);
+ }
case ISD::LOAD:
return LowerLOAD(Op, DAG);
case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index dfde0cca0f00c..a0d5b09253c32 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1575,6 +1575,10 @@ def ADDR : Operand<pAny> {
let MIOperandInfo = (ops ADDR_base, i32imm);
}
+def RegOrSink : Operand<Any> {
+ let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
def AtomicCode : Operand<i32> {
let PrintMethod = "printAtomicCode";
}
@@ -1881,7 +1885,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
"\t[$addr], {{$src1, $src2}};">;
def _v4 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1889,8 +1893,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
if support_v8 then
def _v8 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
- O:$src5, O:$src6, O:$src7, O:$src8,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+ RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 4029e143ae2a4..cd90ae1601e1a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -592,6 +592,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const {
+
+ if (!IsMaskConstant)
+ return false;
+
+ // We currently only support this feature for 256-bit vectors, so the
+ // alignment must be at least 32
+ if (Alignment < 32)
+ return false;
+
+ if (!ST->has256BitVectorLoadStore(AddrSpace))
+ return false;
+
+ auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+ if (!VTy)
+ return false;
+
+ auto *ScalarTy = VTy->getScalarType();
+ if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+ (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+ return true;
+
+ return false;
+}
+
unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
// 256 bit loads/stores are currently only supported for global address space
if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 78eb751cf3c2e..f06b94ba6ce80 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const override;
+
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 6886e8964e29e..2d4f8d7d6a6ec 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -290,7 +290,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant = false) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 146e7d1047dd0..2cb334d8d6952 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1132,7 +1132,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getArgOperand(0)->getType(),
CI->getParamAlign(1).valueOrOne(),
cast<PointerType>(CI->getArgOperand(1)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(3))))
return false;
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<9>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT: ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT: ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT: and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT: setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT: ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT: and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT: setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT: ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT: ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT: ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT: not.pred %p5, %p1;
+; CHECK-NEXT: @%p5 bra $L__BB0_2;
+; CHECK-NEXT: // %bb.1: // %cond.store
+; CHECK-NEXT: st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT: $L__BB0_2: // %else
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT: not.pred %p6, %p2;
+; CHECK-NEXT: @%p6 bra $L__BB0_4;
+; CHECK-NEXT: // %bb.3: // %cond.store1
+; CHECK-NEXT: st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT: $L__BB0_4: // %else2
+; CHECK-NEXT: setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT: not.pred %p7, %p3;
+; CHECK-NEXT: @%p7 bra $L__BB0_6;
+; CHECK-NEXT: // %bb.5: // %cond.store3
+; CHECK-NEXT: st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT: $L__BB0_6: // %else4
+; CHECK-NEXT: not.pred %p8, %p4;
+; CHECK-NEXT: @%p8 bra $L__BB0_8;
+; CHECK-NEXT: // %bb.7: // %cond.store5
+; CHECK-NEXT: st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT: $L__BB0_8: // %else6
+; CHECK-NEXT: ret;
+ %a.load = load <4 x i64>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+ ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..0935bf80b04be
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,319 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll,
+; and contains testing for lowering 256-bit masked vector stores
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores will be legalized via scalarization before reaching the
+;   backend; this file tests them anyway, even though the LSV will not generate them.
+
+; 256-bit vector loads/stores are only legal for Blackwell+, so on sm_90 the vectors will be split
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+ %a.load = load <8 x i32>, ptr %a
+ tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+ %a.load = load <4 x i64>, ptr %a
+ tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+ %a.load = load <8 x float>, ptr %a
+ tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+ %a.load = load <4 x double>, ptr %a
+ tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+ %a.load = load <4 x i64>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x float>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+ %a.load = load <4 x double>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; edge cases
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT: st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT: st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+ ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK: {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+ ret void
+}
+
+; This is an example pattern for the LSV's output of these masked stores
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT: st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT: st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT: ret;
+ %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+ %load05 = extractelement <8 x i32> %1, i32 0
+ %load16 = extractelement <8 x i32> %1, i32 1
+ %load38 = extractelement <8 x i32> %1, i32 3
+ %load49 = extractelement <8 x i32> %1, i32 4
+ %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+ %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+ %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+ %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+ %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+ %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+ %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+ %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+ call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+ ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)
>From 68ab3381409e56182a65a8a0f65af03b59704d91 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:52:13 +0000
Subject: [PATCH 02/11] Clang format
---
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 3 ++-
llvm/lib/Analysis/TargetTransformInfo.cpp | 6 ++++--
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 3 ++-
llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 4 ++--
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp | 3 ++-
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h | 3 ++-
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp | 5 +++--
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h | 4 ++--
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 3 ++-
llvm/lib/Target/VE/VETargetTransformInfo.h | 3 ++-
llvm/lib/Target/X86/X86TargetTransformInfo.cpp | 3 ++-
llvm/lib/Target/X86/X86TargetTransformInfo.h | 3 ++-
12 files changed, 27 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 06df6ca58b122..3d8b6d918d7fa 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,8 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const {
+ unsigned AddressSpace,
+ bool IsMaskConstant) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 9310350b48585..7d02ea8ca5f96 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,8 +468,10 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
+ unsigned AddressSpace,
+ bool IsMaskConstant) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+ IsMaskConstant);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e40631d88748c..669d9e2ae1dad 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ee4f72552d90d..3c5b2195d8dcc 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -189,8 +189,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
unsigned AddressSpace) const override;
- bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
+ bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
}
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index c989bf77a9d51..74df572ef1521 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,8 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e2674bb9cdad7..195cc616347b0 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,8 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const override;
+ unsigned AddressSpace,
+ bool IsMaskConstant) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index cd90ae1601e1a..74c993ce6d7ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -593,11 +593,12 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
}
bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddrSpace, bool IsMaskConstant) const {
+ unsigned AddrSpace,
+ bool IsMaskConstant) const {
if (!IsMaskConstant)
return false;
-
+
// We currently only support this feature for 256-bit vectors, so the
// alignment must be at least 32
if (Alignment < 32)
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index f06b94ba6ce80..80cf9772f76c5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,8 +181,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
- bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddrSpace, bool IsMaskConstant) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+ bool IsMaskConstant) const override;
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 2d4f8d7d6a6ec..8dfa748963a38 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -290,7 +290,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 4971d9148b747..3f33760a5ead9 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,8 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+ unsigned /*AddressSpace*/,
+ bool /*IsMaskConstant*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index b16a2a593df03..851009f89e196 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant) const {
+ unsigned AddressSpace,
+ bool IsMaskConstant) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 7f6ff65d427ed..c7c3d93600878 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,8 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace, bool IsMaskConstant = false) const override;
+ unsigned AddressSpace,
+ bool IsMaskConstant = false) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
>From 8e7140ac245dd9d67e73686c8ac4109ac3b06299 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Fri, 17 Oct 2025 21:35:04 +0000
Subject: [PATCH 03/11] Update TTI to add enum value representing MaskKind to
isLegalMaskedLoad/Store
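
With this change, a target that only supports masked loads/stores when the mask
is a compile-time constant can express that directly through the new
TTI::MaskKind parameter. A minimal sketch of such an override, mirroring the
NVPTX implementation in this patch ("MyTTIImpl" is a placeholder target):

  bool MyTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
                                     unsigned AddrSpace,
                                     TTI::MaskKind MaskKind) const {
    // ScalarizeMaskedMemIntrin passes ConstantMask only when the mask operand
    // is a constant vector; anything else should fall back to scalarization.
    if (MaskKind != TTI::MaskKind::ConstantMask)
      return false;
    // ... target-specific type/alignment/address-space checks go here ...
    return true;
  }

Callers that cannot prove the mask is constant use the default argument
VariableOrConstantMask, so constant-mask-only targets conservatively report
false for them.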
---
.../llvm/Analysis/TargetTransformInfo.h | 22 +++++++++++--------
.../llvm/Analysis/TargetTransformInfoImpl.h | 5 +++--
llvm/lib/Analysis/TargetTransformInfo.cpp | 10 +++++----
.../AArch64/AArch64TargetTransformInfo.h | 5 +++--
.../lib/Target/ARM/ARMTargetTransformInfo.cpp | 3 ++-
llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 8 +++----
.../Hexagon/HexagonTargetTransformInfo.cpp | 5 +++--
.../Hexagon/HexagonTargetTransformInfo.h | 6 ++---
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 16 +++++++++++---
.../Target/NVPTX/NVPTXTargetTransformInfo.h | 5 ++++-
.../Target/RISCV/RISCVTargetTransformInfo.h | 5 +++--
llvm/lib/Target/VE/VETargetTransformInfo.h | 11 +++++-----
.../lib/Target/X86/X86TargetTransformInfo.cpp | 5 +++--
llvm/lib/Target/X86/X86TargetTransformInfo.h | 13 ++++++-----
.../Scalar/ScalarizeMaskedMemIntrin.cpp | 9 ++++++--
15 files changed, 81 insertions(+), 47 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f0d30913e1a52..c61c1a9a9c622 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -813,16 +813,20 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
- /// Return true if the target supports masked store. IsMaskConstant is true
- /// when the mask is known to be a constant vector; a value of false means the
- /// mask may be either variable or constant. This lets targets that only
- /// support masked stores with a constant mask report legality accurately.
- LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace,
- bool IsMaskConstant = false) const;
+ /// Some targets only support masked load/store with a constant mask.
+ enum MaskKind {
+ VariableOrConstantMask,
+ ConstantMask,
+ };
+
+ /// Return true if the target supports masked store.
+ LLVM_ABI bool
+ isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+ MaskKind MaskKind = VariableOrConstantMask) const;
/// Return true if the target supports masked load.
- LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ LLVM_ABI bool
+ isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+ MaskKind MaskKind = VariableOrConstantMask) const;
/// Return true if the target supports nontemporal store.
LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d8b6d918d7fa..653ae750dec52 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -310,12 +310,13 @@ class TargetTransformInfoImplBase {
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace,
- bool IsMaskConstant) const {
+ TTI::MaskKind MaskKind) const {
return false;
}
virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7d02ea8ca5f96..017e9b8d9fe00 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -469,14 +469,16 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace,
- bool IsMaskConstant) const {
+ TTI::MaskKind MaskKind) const {
return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
- IsMaskConstant);
+ MaskKind);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
+ return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
+ MaskKind);
}
bool TargetTransformInfo::isLegalNTStore(Type *DataType,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 669d9e2ae1dad..4fda311738a90 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -316,13 +316,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned /*AddressSpace*/,
- bool /*IsMaskConstant*/) const override {
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 9b250e6cac3ab..7ee7e9a953f5c 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
}
bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const {
if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
return false;
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 3c5b2195d8dcc..401f5ea044dab 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,12 +186,12 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
bool isProfitableLSRChainElement(Instruction *I) const override;
- bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override;
+ bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const override;
bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
- bool /*IsMaskConstant*/) const override {
- return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
+ TTI::MaskKind MaskKind) const override {
+ return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
}
bool forceScalarizeMaskedGather(VectorType *VTy,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 74df572ef1521..da006ff037eeb 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -342,14 +342,15 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
unsigned /*AddressSpace*/,
- bool /*IsMaskConstant*/) const {
+ TTI::MaskKind /*MaskKind*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
}
bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 195cc616347b0..1ca7408d75126 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -167,9 +167,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace,
- bool IsMaskConstant) const override;
- bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ TTI::MaskKind MaskKind) const override;
+ bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const override;
/// @}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 74c993ce6d7ea..ee08d941fb5ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -594,9 +594,8 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
unsigned AddrSpace,
- bool IsMaskConstant) const {
-
- if (!IsMaskConstant)
+ TTI::MaskKind MaskKind) const {
+ if (MaskKind != TTI::MaskKind::ConstantMask)
return false;
// We currently only support this feature for 256-bit vectors, so the
@@ -619,6 +618,17 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
return false;
}
+bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+ unsigned /*AddrSpace*/,
+ TTI::MaskKind MaskKind) const {
+ if (MaskKind != TTI::MaskKind::ConstantMask)
+ return false;
+
+ if (Alignment < DL.getTypeStoreSize(DataTy))
+ return false;
+ return true;
+}
+
unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
// 256 bit loads/stores are currently only supported for global address space
if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 80cf9772f76c5..d7f4e1da4073b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -182,7 +182,10 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
Intrinsic::ID IID) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
- bool IsMaskConstant) const override;
+ TTI::MaskKind MaskKind) const override;
+
+ bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace,
+ TTI::MaskKind MaskKind) const override;
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 8dfa748963a38..7988a6c35c768 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -286,12 +286,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
}
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned /*AddressSpace*/,
- bool /*IsMaskConstant*/) const override {
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 3f33760a5ead9..eed3832c9f1fb 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -134,13 +134,14 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
}
// Load & Store {
- bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ bool
+ isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+ TargetTransformInfo::MaskKind /*MaskKind*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
- bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/,
- bool /*IsMaskConstant*/) const override {
+ bool isLegalMaskedStore(
+ Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+ TargetTransformInfo::MaskKind /*MaskKind*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 851009f89e196..60f148f9f747e 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6317,7 +6317,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
}
bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
@@ -6331,7 +6332,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
unsigned AddressSpace,
- bool IsMaskConstant) const {
+ TTI::MaskKind MaskKind) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index c7c3d93600878..02b642f599270 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -268,11 +268,14 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
const TargetTransformInfo::LSRCost &C2) const override;
bool canMacroFuseCmp() const override;
- bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
- bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace,
- bool IsMaskConstant = false) const override;
+ bool
+ isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override;
+ bool
+ isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 2cb334d8d6952..ff78dd172f38d 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1123,7 +1123,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
if (TTI.isLegalMaskedLoad(
CI->getType(), CI->getParamAlign(0).valueOrOne(),
cast<PointerType>(CI->getArgOperand(0)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(2))
+ ? TTI::MaskKind::ConstantMask
+ : TTI::MaskKind::VariableOrConstantMask))
return false;
scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
@@ -1133,7 +1136,9 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getParamAlign(1).valueOrOne(),
cast<PointerType>(CI->getArgOperand(1)->getType())
->getAddressSpace(),
- isConstantIntVector(CI->getArgOperand(3))))
+ isConstantIntVector(CI->getArgOperand(3))
+ ? TTI::MaskKind::ConstantMask
+ : TTI::MaskKind::VariableOrConstantMask))
return false;
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
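With the mask kind threaded through, the scalarization decision depends on two inputs: whether the mask operand is a constant vector and whether the target reports the operation legal for that kind of mask. A rough sketch of the control flow, with LLVM types replaced by plain parameters (the enum mirrors the TTI::MaskKind introduced in this patch):

#include <iostream>

enum class MaskKind { VariableOrConstantMask, ConstantMask };

// Sketch: keep the masked intrinsic when the target accepts it for this mask
// kind, otherwise fall back to scalarization.
static bool shouldScalarize(bool MaskIsConstantVector,
                            bool (*IsLegalForTarget)(MaskKind)) {
  MaskKind Kind = MaskIsConstantVector ? MaskKind::ConstantMask
                                       : MaskKind::VariableOrConstantMask;
  return !IsLegalForTarget(Kind);
}

// An NVPTX-like target that only handles compile-time-constant masks.
static bool constantMaskOnlyTarget(MaskKind Kind) {
  return Kind == MaskKind::ConstantMask;
}

int main() {
  std::cout << shouldScalarize(true, constantMaskOnlyTarget) << '\n';  // 0: keep
  std::cout << shouldScalarize(false, constantMaskOnlyTarget) << '\n'; // 1: scalarize
  return 0;
}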
>From 6ec5661345e45c43421beafec7c0ab09ade9e5b4 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Fri, 17 Oct 2025 23:53:27 +0000
Subject: [PATCH 04/11] Add masked load lowering support
---
.../SelectionDAG/LegalizeVectorTypes.cpp | 10 +-
.../SelectionDAG/SelectionDAGBuilder.cpp | 2 +
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 10 +
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.h | 1 +
llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp | 2 +-
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 33 +-
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 112 +++++-
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 2 +
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 24 +-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 17 +-
.../Target/NVPTX/NVPTXReplaceImageHandles.cpp | 4 +-
.../Target/NVPTX/NVPTXTagInvariantLoads.cpp | 25 +-
.../floating-point-immediate-operands.mir | 8 +-
.../test/CodeGen/NVPTX/masked-load-vectors.ll | 367 ++++++++++++++++++
llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir | 4 +-
15 files changed, 584 insertions(+), 37 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 3b5f83f7c089a..6bb564ee56018 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2449,6 +2449,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
SDValue PassThru = MLD->getPassThru();
Align Alignment = MLD->getBaseAlign();
ISD::LoadExtType ExtType = MLD->getExtensionType();
+ MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();
// Split Mask operand
SDValue MaskLo, MaskHi;
@@ -2474,9 +2475,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
- MLD->getPointerInfo(), MachineMemOperand::MOLoad,
- LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
- MLD->getRanges());
+ MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
+ Alignment, MLD->getAAInfo(), MLD->getRanges());
Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
MMO, MLD->getAddressingMode(), ExtType,
@@ -2499,8 +2499,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
LoMemVT.getStoreSize().getFixedValue());
MMO = DAG.getMachineFunction().getMachineMemOperand(
- MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
- Alignment, MLD->getAAInfo(), MLD->getRanges());
+ MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
+ MLD->getAAInfo(), MLD->getRanges());
Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
HiMemVT, MMO, MLD->getAddressingMode(), ExtType,
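The only functional change in this hunk is that the split halves inherit the original MachineMemOperand flags instead of a hardcoded MOLoad, so properties such as MOInvariant and MONonTemporal survive type legalization of a wide masked load. A minimal illustration, with a plain bitmask standing in for MachineMemOperand::Flags:

#include <cstdio>

// Stand-ins for MachineMemOperand::Flags bits (values are illustrative).
enum : unsigned { MOLoad = 1u << 0, MONonTemporal = 1u << 1, MOInvariant = 1u << 2 };

int main() {
  unsigned Original = MOLoad | MOInvariant; // e.g. a masked !invariant.load
  unsigned Hardcoded = MOLoad;              // previous behaviour when splitting
  unsigned Forwarded = Original;            // new behaviour: reuse getFlags()
  std::printf("hardcoded half keeps MOInvariant: %d\n",
              (Hardcoded & MOInvariant) != 0);
  std::printf("forwarded half keeps MOInvariant: %d\n",
              (Forwarded & MOInvariant) != 0);
  return 0;
}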
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 20a0efd3afa1c..4cc111e3941ed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5010,6 +5010,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
auto MMOFlags = MachineMemOperand::MOLoad;
if (I.hasMetadata(LLVMContext::MD_nontemporal))
MMOFlags |= MachineMemOperand::MONonTemporal;
+ if (I.hasMetadata(LLVMContext::MD_invariant_load))
+ MMOFlags |= MachineMemOperand::MOInvariant;
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index c8717876157c5..8ca3cb46b5455 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -395,6 +395,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
+void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
+ raw_ostream &O) {
+ auto &Op = MI->getOperand(OpNum);
+ assert(Op.isImm() && "Invalid operand");
+ uint32_t Imm = (uint32_t)Op.getImm();
+ if (Imm != UINT32_MAX) {
+ O << ".pragma \"used_bytes_mask " << Imm << "\";\n\t";
+ }
+}
+
void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
raw_ostream &O,
const char *Modifier) {
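The printer only emits the pragma when the immediate is not the all-ones sentinel, so the output for ordinary loads is unchanged and masked loads gain exactly one extra line in front of the ld instruction. A self-contained sketch of the same formatting (returning a string instead of writing to a raw_ostream):

#include <cstdint>
#include <iostream>
#include <string>

// Sketch of printUsedBytesMaskPragma: UINT32_MAX means "all bytes used" and
// suppresses the pragma entirely.
static std::string usedBytesMaskPragma(uint32_t Mask) {
  if (Mask == UINT32_MAX)
    return "";
  return ".pragma \"used_bytes_mask " + std::to_string(Mask) + "\";\n\t";
}

int main() {
  std::cout << usedBytesMaskPragma(UINT32_MAX) << "ld.global.v4.b32 ...;\n";
  std::cout << usedBytesMaskPragma(3855) << "ld.global.v4.b32 ...;\n";
  return 0;
}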
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index d373668aa591f..89137a954d2d8 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O);
void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index a3496090def3c..c8b53571c1e59 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
const MachineOperand *ParamSymbol = Mov.uses().begin();
assert(ParamSymbol->isSymbol());
- constexpr unsigned LDInstBasePtrOpIdx = 5;
+ constexpr unsigned LDInstBasePtrOpIdx = 6;
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
for (auto *LI : LoadInsts) {
(LI->uses().begin() + LDInstBasePtrOpIdx)
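The off-by-one bumps in this pass and in NVPTXReplaceImageHandles both fall out of the LD instructions growing one extra immediate operand. Counting all machine operands with the destination first, the layout after this change is roughly as sketched below; NVPTXForwardParams indexes into uses(), which skips the def, so its base-pointer index is one lower (5 to 6) than the getOperand() index used in NVPTXReplaceImageHandles (6 to 7), and the MIR tests pick up the extra -1 (all bytes used) operand. The names here are made up for readability and do not appear in the source:

// Illustrative operand order of the scalar LD instructions after this patch
// (mirrors the TableGen ins list; names are hypothetical).
enum LDOperandIdx : unsigned {
  LD_Dst = 0,           // def
  LD_Sem = 1,
  LD_Scope = 2,
  LD_AddrSpace = 3,
  LD_Sign = 4,
  LD_FromWidth = 5,
  LD_UsedBytesMask = 6, // new; -1 / UINT32_MAX for ordinary loads
  LD_AddrBase = 7,      // ADDR expands into a base symbol/register...
  LD_AddrOffset = 8,    // ...plus an immediate offset
};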
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 7e7ee754c250d..8e0399b493a24 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -105,6 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
switch (N->getOpcode()) {
case ISD::LOAD:
case ISD::ATOMIC_LOAD:
+ case NVPTXISD::MLoadV1:
if (tryLoad(N))
return;
break;
@@ -1132,6 +1133,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;
+ uint32_t UsedBytesMask;
+ switch (N->getOpcode()) {
+ case ISD::LOAD:
+ case ISD::ATOMIC_LOAD:
+ UsedBytesMask = UINT32_MAX;
+ break;
+ case NVPTXISD::MLoadV1:
+ UsedBytesMask = N->getConstantOperandVal(N->getNumOperands() - 2);
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && "Invalid width for load");
@@ -1142,6 +1156,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
getI32Imm(CodeAddrSpace, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
+ getI32Imm(UsedBytesMask, DL),
Base,
Offset,
Chain};
@@ -1204,6 +1219,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
: NVPTX::PTXLdStInstCode::Untyped;
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
+ const uint32_t UsedBytesMask =
+ N->getConstantOperandVal(N->getNumOperands() - 2);
assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
@@ -1213,6 +1230,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
getI32Imm(CodeAddrSpace, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
+ getI32Imm(UsedBytesMask, DL),
Base,
Offset,
Chain};
@@ -1250,10 +1268,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
SDLoc DL(LD);
unsigned ExtensionType;
+ uint32_t UsedBytesMask;
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
ExtensionType = Load->getExtensionType();
+ UsedBytesMask = UINT32_MAX;
} else {
ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+ UsedBytesMask = LD->getConstantOperandVal(LD->getNumOperands() - 2);
}
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
@@ -1265,8 +1286,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
ExtensionType != ISD::NON_EXTLOAD));
const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
- SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
- Offset, LD->getChain()};
+ SDValue Ops[] = {getI32Imm(FromType, DL),
+ getI32Imm(FromTypeWidth, DL),
+ getI32Imm(UsedBytesMask, DL),
+ Base,
+ Offset,
+ LD->getChain()};
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
std::optional<unsigned> Opcode;
@@ -1277,6 +1302,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
break;
+ case NVPTXISD::MLoadV1:
+ Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
+ NVPTX::LD_GLOBAL_NC_i64);
+ break;
case NVPTXISD::LoadV2:
Opcode =
pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,
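Both tryLoad and tryLDG rely on a simple operand convention set up by the lowering code further down: for the masked forms the used-bytes mask is appended immediately before the extension-type operand, i.e. at index getNumOperands() - 2, while plain loads behave as if the mask were all ones. A small sketch of that convention on a plain operand array:

#include <cassert>
#include <cstdint>
#include <vector>

enum Opcode { PlainLoad, MLoadV1 };

// Sketch: the used-bytes mask sits right before the extension type (the last
// operand); non-masked loads default to "all bytes used".
static uint32_t usedBytesMaskOf(Opcode Op, const std::vector<uint32_t> &Ops) {
  if (Op != MLoadV1)
    return UINT32_MAX;
  assert(Ops.size() >= 2 && "expected ..., used-bytes mask, extension type");
  return Ops[Ops.size() - 2];
}

int main() {
  // ..., used-bytes mask, extension type (NON_EXTLOAD encoded as 0 here)
  std::vector<uint32_t> Ops = {0, 0, 3855, 0};
  assert(usedBytesMaskOf(MLoadV1, Ops) == 3855);
  assert(usedBytesMaskOf(PlainLoad, Ops) == UINT32_MAX);
  return 0;
}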
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e52854c3b5627..cc23d1159932f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -769,7 +769,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+ Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -1130,6 +1131,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::LoadV2)
MAKE_CASE(NVPTXISD::LoadV4)
MAKE_CASE(NVPTXISD::LoadV8)
+ MAKE_CASE(NVPTXISD::MLoadV1)
MAKE_CASE(NVPTXISD::LDUV2)
MAKE_CASE(NVPTXISD::LDUV4)
MAKE_CASE(NVPTXISD::StoreV2)
@@ -3306,6 +3308,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
}
case ISD::LOAD:
return LowerLOAD(Op, DAG);
+ case ISD::MLOAD:
+ return LowerMLOAD(Op, DAG);
case ISD::SHL_PARTS:
return LowerShiftLeftParts(Op, DAG);
case ISD::SRA_PARTS:
@@ -3497,10 +3501,58 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachinePointerInfo(SV));
}
+static std::tuple<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+ SDValue Chain = N->getOperand(0);
+ SDValue BasePtr = N->getOperand(1);
+ SDValue Mask = N->getOperand(3);
+ SDValue Passthru = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ResVT = N->getValueType(0);
+ assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // the legalization framework creates two undef vectors when it splits a
+  // poison vector in half, so we can technically expect those too.
+ assert((Passthru.getOpcode() == ISD::POISON ||
+ Passthru.getOpcode() == ISD::UNDEF) &&
+ "Passthru operand expected to be poison or undef");
+
+ // Extract the mask and convert it to a uint32_t representing the used bytes
+ // of the entire vector load
+ uint32_t UsedBytesMask = 0;
+ uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+ assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+ uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+ uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+ for (unsigned I :
+ llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
+ assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+ "Mask elements must be constants");
+ // We technically only want to do this shift for every iteration *but* the
+    // first, but in the first iteration UsedBytesMask is 0, so this shift is a
+ // no-op.
+ UsedBytesMask <<= ElementSizeInBytes;
+
+ if (Mask->getConstantOperandVal(I) != 0)
+ UsedBytesMask |= ElementMask;
+ }
+
+ assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+ "Unexpected masked load with elements masked all on or all off");
+
+ // Create a new load sd node to be handled normally by ReplaceLoadVector.
+ MemSDNode *NewLD = cast<MemSDNode>(
+ DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+ return {NewLD, UsedBytesMask};
+}
+
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
static std::optional<std::pair<SDValue, SDValue>>
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
- LoadSDNode *LD = cast<LoadSDNode>(N);
+ MemSDNode *LD = cast<MemSDNode>(N);
const EVT ResVT = LD->getValueType(0);
const EVT MemVT = LD->getMemoryVT();
@@ -3527,6 +3579,14 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
return std::nullopt;
}
+ // If we have a masked load, convert it to a normal load now
+ std::optional<uint32_t> UsedBytesMask = std::nullopt;
+ if (LD->getOpcode() == ISD::MLOAD) {
+ auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
+ LD = std::get<0>(Result);
+ UsedBytesMask = std::get<1>(Result);
+ }
+
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
// loaded type to i16 and propagate the "real" type as the memory type.
@@ -3555,9 +3615,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
// Copy regular operands
SmallVector<SDValue, 8> OtherOps(LD->ops());
+ OtherOps.push_back(
+ DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
// The select routine does not have access to the LoadSDNode instance, so
// pass along the extension information
- OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+ OtherOps.push_back(
+ DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
LD->getMemOperand());
@@ -3645,6 +3709,43 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
llvm_unreachable("Unexpected custom lowering for load");
}
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
+  // handle masked loads of these types and have to handle them here.
+ // v2f32 also needs to be handled here if the subtarget has f32x2
+ // instructions, making it legal.
+ //
+ // Note: misaligned masked loads should never reach this point
+ // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+ // will validate alignment. Therefore, we do not need to special case handle
+ // them here.
+ EVT VT = Op.getValueType();
+ if (NVPTX::isPackedVectorTy(VT) &&
+ (VT != MVT::v2f32 || STI.hasF32x2Instructions())) {
+ auto Result =
+ convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
+ MemSDNode *LD = std::get<0>(Result);
+ uint32_t UsedBytesMask = std::get<1>(Result);
+
+ SDLoc DL(LD);
+
+ // Copy regular operands
+ SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+ OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+ // The select routine does not have access to the LoadSDNode instance, so
+ // pass along the extension information
+ OtherOps.push_back(
+ DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+ SDValue NewLD = DAG.getMemIntrinsicNode(
+ NVPTXISD::MLoadV1, DL, LD->getVTList(), OtherOps, LD->getMemoryVT(),
+ LD->getMemOperand());
+ return NewLD;
+ }
+ return SDValue();
+}
+
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
const NVPTXSubtarget &STI) {
MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5555,9 +5656,13 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
// ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
// here.
Opcode = NVPTXISD::LoadV2;
+    // Append a "full" used bytes mask operand right before the extension type
+ // operand, signifying that all bytes are used.
+ Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
+ // TODO do we need to support MLoadV1 here?
case NVPTXISD::LoadV2:
OldNumOutputs = 2;
Opcode = NVPTXISD::LoadV4;
@@ -6793,6 +6898,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
ReplaceBITCAST(N, DAG, Results);
return;
case ISD::LOAD:
+ case ISD::MLOAD:
replaceLoadVector(N, DAG, Results, STI);
return;
case ISD::INTRINSIC_W_CHAIN:
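The heart of convertMLOADToLoadWithUsedBytesMask is the expansion of the element-level <N x i1> constant mask into a byte-level 32-bit mask: each enabled element contributes a run of set bits equal to its size in bytes at its byte offset inside the vector, and the loop walks the elements from highest to lowest so the shift accumulates in the right order. A standalone sketch of that expansion (a plain std::vector<bool> in place of the SDNode mask operands):

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: expand an element mask into a used-bytes mask. The vector is at
// most 32 bytes in total (a 256-bit access), so the result fits in a
// uint32_t; element sizes of interest are 1 to 8 bytes.
static uint32_t computeUsedBytesMask(const std::vector<bool> &ElementMask,
                                     unsigned ElementSizeInBytes) {
  assert(ElementMask.size() * ElementSizeInBytes <= 32 && "mask would overflow");
  uint32_t ElementBits = (1u << ElementSizeInBytes) - 1u; // 0xF for 4-byte lanes
  uint32_t UsedBytesMask = 0;
  for (auto It = ElementMask.rbegin(); It != ElementMask.rend(); ++It) {
    UsedBytesMask <<= ElementSizeInBytes; // no-op on the first iteration
    if (*It)
      UsedBytesMask |= ElementBits;
  }
  return UsedBytesMask;
}

int main() {
  // <8 x i32> with mask <1,0,1,0,0,0,0,1>: elements 0, 2 and 7 are live.
  assert(computeUsedBytesMask({true, false, true, false,
                               false, false, false, true}, 4) == 0xF0000F0Fu);
  // The SM90 path splits the same vector into two v4 halves:
  assert(computeUsedBytesMask({true, false, true, false}, 4) == 0x0F0Fu); // 3855
  assert(computeUsedBytesMask({false, false, false, true}, 4) == 0xF000u); // 61440
  return 0;
}

These three values (0xF0000F0F is 4026535695) are exactly the used_bytes_mask immediates that appear in the SM100 and SM90 CHECK lines of the new test file.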
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63fa0bb9159ff..89bf0c290292a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -99,6 +99,7 @@ enum NodeType : unsigned {
LoadV2,
LoadV4,
LoadV8,
+ MLoadV1,
LDUV2, // LDU.v2
LDUV4, // LDU.v4
StoreV2,
@@ -349,6 +350,7 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a0d5b09253c32..24a40ab627a30 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1575,6 +1575,10 @@ def ADDR : Operand<pAny> {
let MIOperandInfo = (ops ADDR_base, i32imm);
}
+def UsedBytesMask : Operand<i32> {
+ let PrintMethod = "printUsedBytesMaskPragma";
+}
+
def RegOrSink : Operand<Any> {
let PrintMethod = "printRegisterOrSinkSymbol";
}
@@ -1817,8 +1821,10 @@ def Callseq_End :
class LD<NVPTXRegClass regclass>
: NVPTXInst<
(outs regclass:$dst),
- (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
- i32imm:$fromWidth, ADDR:$addr),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$fromWidth "
"\t$dst, [$addr];">;
@@ -1850,21 +1856,27 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
def _v2 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
- AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2}}, [$addr];">;
def _v4 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
- AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];">;
if support_v8 then
def _v8 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
- (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
- i32imm:$fromWidth, ADDR:$addr),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
"[$addr];">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 22cf3a7eef2c1..65c951e7aa06d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2351,7 +2351,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
// during the lifetime of the kernel.
class LDG_G<NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ : NVPTXInst<(outs regclass:$result),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth,
+ UsedBytesMask:$usedBytes, ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">;
def LD_GLOBAL_NC_i16 : LDG_G<B16>;
@@ -2363,19 +2366,25 @@ def LD_GLOBAL_NC_i64 : LDG_G<B64>;
// Elementized vector ldg
class VLDG_G_ELE_V2<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">;
class VLDG_G_ELE_V4<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">;
class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">;
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index 320c0fb6950a7..4bbf49f93f43b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -1808,8 +1808,8 @@ bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
// For CUDA, we preserve the param loads coming from function arguments
return false;
- assert(TexHandleDef.getOperand(6).isSymbol() && "Load is not a symbol!");
- StringRef Sym = TexHandleDef.getOperand(6).getSymbolName();
+ assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
+ StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
InstrsToRemove.insert(&TexHandleDef);
Op.ChangeToES(Sym.data());
MFI->getImageHandleSymbolIndex(Sym);
diff --git a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
index a4aff44ac04f6..6fa518e8d409b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
@@ -27,13 +27,14 @@
using namespace llvm;
-static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
+static bool isInvariantLoad(const Instruction *I, const Value *Ptr,
+ const bool IsKernelFn) {
// Don't bother with non-global loads
- if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
+ if (Ptr->getType()->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
return false;
// If the load is already marked as invariant, we don't need to do anything
- if (LI->getMetadata(LLVMContext::MD_invariant_load))
+ if (I->getMetadata(LLVMContext::MD_invariant_load))
return false;
// We use getUnderlyingObjects() here instead of getUnderlyingObject()
@@ -41,7 +42,7 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
// not. We need to look through phi nodes to handle pointer induction
// variables.
SmallVector<const Value *, 8> Objs;
- getUnderlyingObjects(LI->getPointerOperand(), Objs);
+ getUnderlyingObjects(Ptr, Objs);
return all_of(Objs, [&](const Value *V) {
if (const auto *A = dyn_cast<const Argument>(V))
@@ -53,9 +54,9 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
});
}
-static void markLoadsAsInvariant(LoadInst *LI) {
- LI->setMetadata(LLVMContext::MD_invariant_load,
- MDNode::get(LI->getContext(), {}));
+static void markLoadsAsInvariant(Instruction *I) {
+ I->setMetadata(LLVMContext::MD_invariant_load,
+ MDNode::get(I->getContext(), {}));
}
static bool tagInvariantLoads(Function &F) {
@@ -64,10 +65,18 @@ static bool tagInvariantLoads(Function &F) {
bool Changed = false;
for (auto &I : instructions(F)) {
if (auto *LI = dyn_cast<LoadInst>(&I)) {
- if (isInvariantLoad(LI, IsKernelFn)) {
+ if (isInvariantLoad(LI, LI->getPointerOperand(), IsKernelFn)) {
markLoadsAsInvariant(LI);
Changed = true;
}
+ if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+ if (II->getIntrinsicID() == Intrinsic::masked_load) {
+ if (isInvariantLoad(II, II->getOperand(0), IsKernelFn)) {
+ markLoadsAsInvariant(II);
+ Changed = true;
+ }
+ }
+ }
}
}
return Changed;
diff --git a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
index e3b072549bc04..3158916a3195c 100644
--- a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
+++ b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
@@ -40,9 +40,9 @@ registers:
- { id: 7, class: b32 }
body: |
bb.0.entry:
- %0 = LD_i32 0, 0, 4, 2, 32, &test_param_0, 0
+ %0 = LD_i32 0, 0, 4, 2, 32, -1, &test_param_0, 0
%1 = CVT_f64_f32 %0, 0
- %2 = LD_i32 0, 0, 4, 0, 32, &test_param_1, 0
+ %2 = LD_i32 0, 0, 4, 0, 32, -1, &test_param_1, 0
; CHECK: %3:b64 = FADD_rnf64ri %1, double 3.250000e+00
%3 = FADD_rnf64ri %1, double 3.250000e+00
%4 = CVT_f32_f64 %3, 5
@@ -66,9 +66,9 @@ registers:
- { id: 7, class: b32 }
body: |
bb.0.entry:
- %0 = LD_i32 0, 0, 4, 2, 32, &test2_param_0, 0
+ %0 = LD_i32 0, 0, 4, 2, 32, -1, &test2_param_0, 0
%1 = CVT_f64_f32 %0, 0
- %2 = LD_i32 0, 0, 4, 0, 32, &test2_param_1, 0
+ %2 = LD_i32 0, 0, 4, 0, 32, -1, &test2_param_1, 0
; CHECK: %3:b64 = FADD_rnf64ri %1, double 0x7FF8000000000000
%3 = FADD_rnf64ri %1, double 0x7FF8000000000000
%4 = CVT_f32_f64 %3, 5
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
new file mode 100644
index 0000000000000..2e58aae6ad478
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -0,0 +1,367 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+
+; Different architectures are tested in this file for the following reasons:
+; - SM90 does not have 256-bit load/store instructions
+; - SM90 does not have masked store instructions
+; - SM90 does not support packed f32x2 instructions
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 61440";
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: .pragma "used_bytes_mask 3855";
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+; Masked stores are only supported for 32 byte element types,
+; Masked stores are only supported for 32-bit element types,
+; while masked loads are supported for all element types.
+define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_16xi16(
+; SM90: {
+; SM90-NEXT: .reg .b16 %rs<7>;
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 61440";
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r4;
+; SM90-NEXT: .pragma "used_bytes_mask 3855";
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r7;
+; SM90-NEXT: mov.b32 {%rs5, %rs6}, %r5;
+; SM90-NEXT: ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM90-NEXT: st.global.b16 [%rd2], %rs5;
+; SM90-NEXT: st.global.b16 [%rd2+2], %rs6;
+; SM90-NEXT: st.global.b16 [%rd2+8], %rs3;
+; SM90-NEXT: st.global.b16 [%rd2+10], %rs4;
+; SM90-NEXT: st.global.b16 [%rd2+28], %rs1;
+; SM90-NEXT: st.global.b16 [%rd2+30], %rs2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_16xi16(
+; SM100: {
+; SM100-NEXT: .reg .b16 %rs<7>;
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: mov.b32 {%rs1, %rs2}, %r8;
+; SM100-NEXT: mov.b32 {%rs3, %rs4}, %r3;
+; SM100-NEXT: mov.b32 {%rs5, %rs6}, %r1;
+; SM100-NEXT: ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM100-NEXT: st.global.b16 [%rd2], %rs5;
+; SM100-NEXT: st.global.b16 [%rd2+2], %rs6;
+; SM100-NEXT: st.global.b16 [%rd2+8], %rs3;
+; SM100-NEXT: st.global.b16 [%rd2+10], %rs4;
+; SM100-NEXT: st.global.b16 [%rd2+28], %rs1;
+; SM100-NEXT: st.global.b16 [%rd2+30], %rs2;
+; SM100-NEXT: ret;
+ %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) %a, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
+ tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) %b, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
+ ret void
+}
+
+define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_8xi32_no_align_param_0];
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_8xi32_no_align_param_1];
+; CHECK-NEXT: ld.global.b32 %r2, [%rd1+8];
+; CHECK-NEXT: ld.global.b32 %r3, [%rd1+28];
+; CHECK-NEXT: st.global.b32 [%rd2], %r1;
+; CHECK-NEXT: st.global.b32 [%rd2+8], %r2;
+; CHECK-NEXT: st.global.b32 [%rd2+28], %r3;
+; CHECK-NEXT: ret;
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+
+define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_invariant(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 61440";
+; SM90-NEXT: ld.global.nc.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: .pragma "used_bytes_mask 3855";
+; SM90-NEXT: ld.global.nc.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32_invariant(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 3";
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_param_1];
+; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_invariant(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_invariant_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 3";
+; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_invariant_param_1];
+; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_no_align_param_0];
+; CHECK-NEXT: ld.global.b16 %rs1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_no_align_param_1];
+; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 2, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 5";
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_param_1];
+; CHECK-NEXT: st.global.b8 [%rd2], %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT: ret;
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_invariant(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_invariant_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 5";
+; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_invariant_param_1];
+; CHECK-NEXT: st.global.b8 [%rd2], %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT: ret;
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_no_align_param_0];
+; CHECK-NEXT: ld.global.b8 %rs1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_no_align_param_1];
+; CHECK-NEXT: ld.global.b8 %rs2, [%rd1+2];
+; CHECK-NEXT: st.global.b8 [%rd2], %rs1;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %rs2;
+; CHECK-NEXT: ret;
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 2, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; In sm100+, we pack 2xf32 loads into a single b64 load while lowering
+define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<3>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 15";
+; SM90-NEXT: ld.global.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_2xf32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<2>;
+; SM100-NEXT: .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 15";
+; SM100-NEXT: ld.global.b64 %rd2, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_param_1];
+; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT: st.global.b32 [%rd3], %r1;
+; SM100-NEXT: ret;
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32_invariant(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<3>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 15";
+; SM90-NEXT: ld.global.nc.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_invariant_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_2xf32_invariant(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<2>;
+; SM100-NEXT: .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 15";
+; SM100-NEXT: ld.global.nc.b64 %rd2, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_invariant_param_1];
+; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT: st.global.b32 [%rd3], %r1;
+; SM100-NEXT: ret;
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xf32_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xf32_no_align_param_0];
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
+; CHECK-NEXT: st.global.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), i32, <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), i32, <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), i32, <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), i32, <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), i32, <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), i32, <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), i32, <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), i32, <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), i32, <2 x i1>)
+!0 = !{}
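To read the CHECK lines above: the pragma operand is the used-bytes mask printed in decimal, so 4026535695 is 0xF0000F0F (bytes 0-3, 8-11 and 28-31 of the 32-byte access), and the two SM90 halves, 3855 and 61440, are 0x0F0F and 0xF000, with the second half's offsets relative to the [%rd1+16] base. A tiny decoder, for illustration only:

#include <cstdint>
#include <cstdio>

// Print the byte offsets covered by a used_bytes_mask value (the pragma
// prints the same number in decimal).
static void decodeUsedBytesMask(uint32_t Mask) {
  std::printf("0x%08X:", (unsigned)Mask);
  for (unsigned Byte = 0; Byte < 32; ++Byte)
    if (Mask & (1u << Byte))
      std::printf(" %u", Byte);
  std::printf("\n");
}

int main() {
  decodeUsedBytesMask(4026535695u); // 0xF0000F0F -> bytes 0-3, 8-11, 28-31
  decodeUsedBytesMask(3855u);       // 0x0F0F -> bytes 0-3, 8-11 (lower half)
  decodeUsedBytesMask(61440u);      // 0xF000 -> bytes 12-15 (upper half, +16 base)
  return 0;
}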
diff --git a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
index dfc84177fb0e6..a84b7fcd33836 100644
--- a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
+++ b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
@@ -77,7 +77,7 @@ constants: []
machineFunctionInfo: {}
body: |
bb.0:
- %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, &retval0, 0 :: (load (s128), addrspace 101)
+ %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s128), addrspace 101)
; CHECK-NOT: ProxyReg
%4:b32 = ProxyRegB32 killed %0
%5:b32 = ProxyRegB32 killed %1
@@ -86,7 +86,7 @@ body: |
; CHECK: STV_i32_v4 killed %0, killed %1, killed %2, killed %3
STV_i32_v4 killed %4, killed %5, killed %6, killed %7, 0, 0, 101, 32, &func_retval0, 0 :: (store (s128), addrspace 101)
- %8:b32 = LD_i32 0, 0, 101, 3, 32, &retval0, 0 :: (load (s32), addrspace 101)
+ %8:b32 = LD_i32 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s32), addrspace 101)
; CHECK-NOT: ProxyReg
%9:b32 = ProxyRegB32 killed %8
%10:b32 = ProxyRegB32 killed %9
>From c44f8a2d8c528230fd303d847d15c3335c87c9a4 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Sat, 18 Oct 2025 00:27:50 +0000
Subject: [PATCH 05/11] Fix typo
---
llvm/test/CodeGen/NVPTX/masked-load-vectors.ll | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
index 2e58aae6ad478..5b4136ec1307e 100644
--- a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -45,7 +45,6 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
ret void
}
-; Masked stores are only supported for 32 byte element types,
; Masked stores are only supported for 32-bit element types,
; while masked loads are supported for all element types.
define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
>From abbf0ba102a8ecd1334c4fbd88fd12e1463d5b55 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Tue, 21 Oct 2025 16:39:12 +0000
Subject: [PATCH 06/11] Review feedback
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 30 ++++++++-----------
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 9 ++----
2 files changed, 15 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index cc23d1159932f..6a5ea30ea0a0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3230,10 +3230,9 @@ static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
assert(Mask.getValueType().getVectorNumElements() ==
ValVT.getVectorNumElements() &&
"Mask size must be the same as the vector size");
- for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
- assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
- "Mask elements must be constants");
- if (Mask->getConstantOperandVal(I) == 0) {
+ for (auto [I, Op] : enumerate(Mask->ops())) {
+ // Mask elements must be constants.
+ if (Op.getNode()->getAsZExtVal() == 0) {
// Append a sentinel register 0 to the Ops vector to represent a masked
// off element, this will be handled in tablegen
Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
@@ -3501,7 +3500,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachinePointerInfo(SV));
}
-static std::tuple<MemSDNode *, uint32_t>
+static std::pair<MemSDNode *, uint32_t>
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
SDValue Chain = N->getOperand(0);
SDValue BasePtr = N->getOperand(1);
@@ -3526,16 +3525,14 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
- for (unsigned I :
- llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
- assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
- "Mask elements must be constants");
- // We technically only want to do this shift for every iteration *but* the
-    // first, but in the first iteration UsedBytesMask is 0, so this shift is a
- // no-op.
+ for (SDValue Op : llvm::reverse(Mask->ops())) {
+ // We technically only want to do this shift for every
+    // iteration *but* the first, but in the first iteration UsedBytesMask is 0, so
+ // this shift is a no-op.
UsedBytesMask <<= ElementSizeInBytes;
- if (Mask->getConstantOperandVal(I) != 0)
+ // Mask elements must be constants.
+ if (Op->getAsZExtVal() != 0)
UsedBytesMask |= ElementMask;
}
@@ -3581,11 +3578,8 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
// If we have a masked load, convert it to a normal load now
std::optional<uint32_t> UsedBytesMask = std::nullopt;
- if (LD->getOpcode() == ISD::MLOAD) {
- auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
- LD = std::get<0>(Result);
- UsedBytesMask = std::get<1>(Result);
- }
+ if (LD->getOpcode() == ISD::MLOAD)
+ std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index ee08d941fb5ea..00bb129cbf801 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -610,12 +610,9 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
if (!VTy)
return false;
- auto *ScalarTy = VTy->getScalarType();
- if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
- (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
- return true;
-
- return false;
+ auto *ElemTy = VTy->getScalarType();
+ return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+ (ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4);
}
bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
>From 8bf4ed6b245a6cff60a741ba1c62a3c884b20762 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Tue, 21 Oct 2025 19:03:55 +0000
Subject: [PATCH 07/11] More review feedback
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 6a5ea30ea0a0a..8c86912dc95a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3525,7 +3525,7 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
- for (SDValue Op : llvm::reverse(Mask->ops())) {
+ for (SDValue Op : reverse(Mask->ops())) {
// We technically only want to do this shift for every
// iteration *but* the first, but in the first iteration NewMask is 0, so
// this shift is a no-op.
@@ -3714,8 +3714,7 @@ SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
// will validate alignment. Therefore, we do not need to special case handle
// them here.
EVT VT = Op.getValueType();
- if (NVPTX::isPackedVectorTy(VT) &&
- (VT != MVT::v2f32 || STI.hasF32x2Instructions())) {
+ if (NVPTX::isPackedVectorTy(VT)) {
auto Result =
convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
MemSDNode *LD = std::get<0>(Result);
@@ -5656,7 +5655,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
- // TODO do we need to support MLoadV1 here?
case NVPTXISD::LoadV2:
OldNumOutputs = 2;
Opcode = NVPTXISD::LoadV4;
>From 7315fbd119f469614f364e07411c1ff67183437f Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Tue, 21 Oct 2025 19:24:59 +0000
Subject: [PATCH 08/11] Fix ARM TTI
---
llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 401f5ea044dab..9a940cb3453b7 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,11 +186,15 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
bool isProfitableLSRChainElement(Instruction *I) const override;
- bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
- TTI::MaskKind MaskKind) const override;
-
- bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
- TTI::MaskKind MaskKind) const override {
+ bool
+ isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override;
+
+ bool
+ isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
}
>From 95c335cf3dcf0ebb501521f9f00a9058e1e7fecf Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 22 Oct 2025 17:04:21 +0000
Subject: [PATCH 09/11] Make fixes based on recent TOT changes, adjust tests,
expand LoadV8 unpacking mov handling for v2i32 packed types
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 12 ++--
.../Scalar/ScalarizeMaskedMemIntrin.cpp | 4 +-
llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll | 16 ++---
.../NVPTX/machinelicm-no-preheader.mir | 12 ++--
.../test/CodeGen/NVPTX/masked-load-vectors.ll | 72 +++++++++----------
.../NVPTX/masked-store-variable-mask.ll | 4 +-
.../CodeGen/NVPTX/masked-store-vectors-256.ll | 47 ++++++------
7 files changed, 81 insertions(+), 86 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8c86912dc95a4..73b2e930b7b8e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5660,9 +5660,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
Opcode = NVPTXISD::LoadV4;
break;
case NVPTXISD::LoadV4:
- // V8 is only supported for f32. Don't forget, we're not changing the load
- // size here. This is already a 256-bit load.
- if (ElementVT != MVT::v2f32)
+ // V8 is only supported for f32/i32. Don't forget, we're not changing the
+ // load size here. This is already a 256-bit load.
+ if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
return SDValue();
OldNumOutputs = 4;
Opcode = NVPTXISD::LoadV8;
@@ -5737,9 +5737,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
- // V8 is only supported for f32. Don't forget, we're not changing the store
- // size here. This is already a 256-bit store.
- if (ElementVT != MVT::v2f32)
+ // V8 is only supported for f32/i32. Don't forget, we're not changing the
+ // store size here. This is already a 256-bit store.
+ if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
return SDValue();
Opcode = NVPTXISD::StoreV8;
break;
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index ff78dd172f38d..b7b08ae61ec52 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1124,7 +1124,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getType(), CI->getParamAlign(0).valueOrOne(),
cast<PointerType>(CI->getArgOperand(0)->getType())
->getAddressSpace(),
- isConstantIntVector(CI->getArgOperand(2))
+ isConstantIntVector(CI->getArgOperand(1))
? TTI::MaskKind::ConstantMask
: TTI::MaskKind::VariableOrConstantMask))
return false;
@@ -1136,7 +1136,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getParamAlign(1).valueOrOne(),
cast<PointerType>(CI->getArgOperand(1)->getType())
->getAddressSpace(),
- isConstantIntVector(CI->getArgOperand(3))
+ isConstantIntVector(CI->getArgOperand(2))
? TTI::MaskKind::ConstantMask
: TTI::MaskKind::VariableOrConstantMask))
return false;
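[Editor's note] The index changes above track the alignment-less masked intrinsic forms exercised by the updated tests below: llvm.masked.load takes (ptr, mask, passthru) and llvm.masked.store takes (value, ptr, mask), with alignment carried by the pointer argument's align attribute rather than an explicit i32 operand. A hypothetical helper (not part of the patch) that captures the same operand positions:

  #include "llvm/IR/IntrinsicInst.h"
  #include "llvm/Support/ErrorHandling.h"
  using namespace llvm;

  // Return the mask operand of a masked load/store call, using the operand
  // indices assumed by the hunk above.
  static Value *getMaskOperand(CallInst *CI) {
    switch (CI->getIntrinsicID()) {
    case Intrinsic::masked_load:
      return CI->getArgOperand(1); // (ptr, mask, passthru)
    case Intrinsic::masked_store:
      return CI->getArgOperand(2); // (value, ptr, mask)
    default:
      llvm_unreachable("expected masked load/store intrinsic");
    }
  }

Only the two indices shown in the hunk are taken from the patch; the helper itself is an illustration.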
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
index 3fac29f74125b..d219493d2b31b 100644
--- a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -346,19 +346,15 @@ define i32 @ld_global_v8i32(ptr addrspace(1) %ptr) {
; SM100-LABEL: ld_global_v8i32(
; SM100: {
; SM100-NEXT: .reg .b32 %r<16>;
-; SM100-NEXT: .reg .b64 %rd<6>;
+; SM100-NEXT: .reg .b64 %rd<2>;
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [ld_global_v8i32_param_0];
-; SM100-NEXT: ld.global.nc.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
-; SM100-NEXT: mov.b64 {%r1, %r2}, %rd5;
-; SM100-NEXT: mov.b64 {%r3, %r4}, %rd4;
-; SM100-NEXT: mov.b64 {%r5, %r6}, %rd3;
-; SM100-NEXT: mov.b64 {%r7, %r8}, %rd2;
-; SM100-NEXT: add.s32 %r9, %r7, %r8;
-; SM100-NEXT: add.s32 %r10, %r5, %r6;
-; SM100-NEXT: add.s32 %r11, %r3, %r4;
-; SM100-NEXT: add.s32 %r12, %r1, %r2;
+; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: add.s32 %r9, %r1, %r2;
+; SM100-NEXT: add.s32 %r10, %r3, %r4;
+; SM100-NEXT: add.s32 %r11, %r5, %r6;
+; SM100-NEXT: add.s32 %r12, %r7, %r8;
; SM100-NEXT: add.s32 %r13, %r9, %r10;
; SM100-NEXT: add.s32 %r14, %r11, %r12;
; SM100-NEXT: add.s32 %r15, %r13, %r14;
diff --git a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
index 0b2d85600a2ef..4be91dfc60c6a 100644
--- a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
+++ b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
@@ -26,10 +26,10 @@ body: |
; CHECK: bb.0.entry:
; CHECK-NEXT: successors: %bb.2(0x30000000), %bb.3(0x50000000)
; CHECK-NEXT: {{ $}}
- ; CHECK-NEXT: [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
- ; CHECK-NEXT: [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+ ; CHECK-NEXT: [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+ ; CHECK-NEXT: [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
; CHECK-NEXT: [[ADD64ri:%[0-9]+]]:b64 = nuw ADD64ri killed [[LD_i64_]], 2
- ; CHECK-NEXT: [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, [[ADD64ri]], 0
+ ; CHECK-NEXT: [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, -1, [[ADD64ri]], 0
; CHECK-NEXT: [[SETP_i32ri:%[0-9]+]]:b1 = SETP_i32ri [[LD_i32_]], 0, 0
; CHECK-NEXT: CBranch killed [[SETP_i32ri]], %bb.2
; CHECK-NEXT: {{ $}}
@@ -54,10 +54,10 @@ body: |
bb.0.entry:
successors: %bb.2(0x30000000), %bb.1(0x50000000)
- %5:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
- %6:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+ %5:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+ %6:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
%0:b64 = nuw ADD64ri killed %6, 2
- %1:b32 = LD_i32 0, 0, 1, 3, 32, %0, 0
+ %1:b32 = LD_i32 0, 0, 1, 3, 32, -1, %0, 0
%7:b1 = SETP_i32ri %5, 0, 0
CBranch killed %7, %bb.2
GOTO %bb.1
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
index 5b4136ec1307e..7c7c51be9567d 100644
--- a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -40,8 +40,8 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
; SM100-NEXT: ret;
- %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
- tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -93,8 +93,8 @@ define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: st.global.b16 [%rd2+28], %rs1;
; SM100-NEXT: st.global.b16 [%rd2+30], %rs2;
; SM100-NEXT: ret;
- %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) %a, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
- tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) %b, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
+ %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) align 32 %a, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
+ tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) align 32 %b, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
ret void
}
@@ -114,8 +114,8 @@ define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: st.global.b32 [%rd2+8], %r2;
; CHECK-NEXT: st.global.b32 [%rd2+28], %r3;
; CHECK-NEXT: ret;
- %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
- tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 16 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 16 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -150,8 +150,8 @@ define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
; SM100-NEXT: ret;
- %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
- tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -170,8 +170,8 @@ define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
; CHECK-NEXT: ret;
- %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
- tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
ret void
}
@@ -190,8 +190,8 @@ define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
; CHECK-NEXT: ret;
- %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
- tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
ret void
}
@@ -207,8 +207,8 @@ define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_no_align_param_1];
; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
; CHECK-NEXT: ret;
- %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 2, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
- tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 2 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
ret void
}
@@ -227,8 +227,8 @@ define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
; CHECK-NEXT: ret;
- %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
- tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -247,8 +247,8 @@ define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
; CHECK-NEXT: ret;
- %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
- tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -266,8 +266,8 @@ define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: st.global.b8 [%rd2], %rs1;
; CHECK-NEXT: st.global.b8 [%rd2+2], %rs2;
; CHECK-NEXT: ret;
- %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 2, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
- tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 2 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -299,8 +299,8 @@ define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
; SM100-NEXT: st.global.b32 [%rd3], %r1;
; SM100-NEXT: ret;
- %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
- tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
ret void
}
@@ -331,8 +331,8 @@ define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
; SM100-NEXT: st.global.b32 [%rd3], %r1;
; SM100-NEXT: ret;
- %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
- tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
ret void
}
@@ -348,19 +348,19 @@ define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-NEXT: ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
; CHECK-NEXT: st.global.b32 [%rd2], %r1;
; CHECK-NEXT: ret;
- %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
- tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
ret void
}
-declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), i32, <8 x i1>, <8 x i32>)
-declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
-declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), i32, <16 x i1>, <16 x i16>)
-declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), i32, <16 x i1>)
-declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), i32, <2 x i1>, <2 x i16>)
-declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), i32, <2 x i1>)
-declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), i32, <4 x i1>, <4 x i8>)
-declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), i32, <4 x i1>)
-declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), i32, <2 x i1>, <2 x float>)
-declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), i32, <2 x i1>)
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), <2 x i1>)
!0 = !{}
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
index 7d8f65b25bb02..9f23acaf93bc8 100644
--- a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -49,8 +49,8 @@ define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x
; CHECK-NEXT: $L__BB0_8: // %else6
; CHECK-NEXT: ret;
%a.load = load <4 x i64>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> %mask)
ret void
}
-declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
index 0935bf80b04be..feb7b7e0a0b39 100644
--- a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -34,7 +34,7 @@ define void @generic_8xi32(ptr %a, ptr %b) {
; CHECK-NEXT: st.b32 [%rd2+28], %r4;
; CHECK-NEXT: ret;
%a.load = load <8 x i32>, ptr %a
- tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -52,7 +52,7 @@ define void @generic_4xi64(ptr %a, ptr %b) {
; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
; CHECK-NEXT: ret;
%a.load = load <4 x i64>, ptr %a
- tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -72,7 +72,7 @@ define void @generic_8xfloat(ptr %a, ptr %b) {
; CHECK-NEXT: st.b32 [%rd2+28], %r4;
; CHECK-NEXT: ret;
%a.load = load <8 x float>, ptr %a
- tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -90,7 +90,7 @@ define void @generic_4xdouble(ptr %a, ptr %b) {
; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
; CHECK-NEXT: ret;
%a.load = load <4 x double>, ptr %a
- tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -124,7 +124,7 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
; SM100-NEXT: ret;
%a.load = load <8 x i32>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -153,7 +153,7 @@ define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
; SM100-NEXT: ret;
%a.load = load <4 x i64>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -185,7 +185,7 @@ define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
; SM100-NEXT: ret;
%a.load = load <8 x float>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
ret void
}
@@ -214,7 +214,7 @@ define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
; SM100-NEXT: ret;
%a.load = load <4 x double>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
ret void
}
@@ -236,17 +236,16 @@ define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b)
;
; SM100-LABEL: global_8xi32_all_mask_on(
; SM100: {
-; SM100-NEXT: .reg .b32 %r<9>;
-; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-NEXT: .reg .b64 %rd<7>;
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
-; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
-; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
-; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, %rd3, %rd4, %rd5};
; SM100-NEXT: ret;
%a.load = load <8 x i32>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
ret void
}
@@ -258,7 +257,7 @@ define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b)
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ret;
%a.load = load <8 x i32>, ptr addrspace(1) %a
- tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
ret void
}
@@ -304,16 +303,16 @@ define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
%7 = insertelement <8 x i32> %6, i32 poison, i32 5
%8 = insertelement <8 x i32> %7, i32 poison, i32 6
%9 = insertelement <8 x i32> %8, i32 poison, i32 7
- call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+ call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) align 32 %out, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
ret void
}
-declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
-declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
-declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
-declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, <4 x i1>)
-declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
-declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
-declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
-declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), <4 x i1>)
>From e7c3e91ab7e10380b82f4c6b27b81166f4a8de1b Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Mon, 27 Oct 2025 20:30:31 +0000
Subject: [PATCH 10/11] Review feedback
---
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 8 ++++----
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 12 ++++++------
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 2 +-
llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp | 12 +++++-------
4 files changed, 16 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 8e0399b493a24..5a9df54f029a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -105,7 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
switch (N->getOpcode()) {
case ISD::LOAD:
case ISD::ATOMIC_LOAD:
- case NVPTXISD::MLoadV1:
+ case NVPTXISD::MLoad:
if (tryLoad(N))
return;
break;
@@ -1139,8 +1139,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
case ISD::ATOMIC_LOAD:
UsedBytesMask = UINT32_MAX;
break;
- case NVPTXISD::MLoadV1:
- UsedBytesMask = N->getConstantOperandVal(N->getNumOperands() - 2);
+ case NVPTXISD::MLoad:
+ UsedBytesMask = N->getConstantOperandVal(3);
break;
default:
llvm_unreachable("Unexpected opcode");
@@ -1302,7 +1302,7 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
break;
- case NVPTXISD::MLoadV1:
+ case NVPTXISD::MLoad:
Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
NVPTX::LD_GLOBAL_NC_i64);
break;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 73b2e930b7b8e..699498d50471c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1131,7 +1131,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::LoadV2)
MAKE_CASE(NVPTXISD::LoadV4)
MAKE_CASE(NVPTXISD::LoadV8)
- MAKE_CASE(NVPTXISD::MLoadV1)
+ MAKE_CASE(NVPTXISD::MLoad)
MAKE_CASE(NVPTXISD::LDUV2)
MAKE_CASE(NVPTXISD::LDUV4)
MAKE_CASE(NVPTXISD::StoreV2)
@@ -3727,13 +3727,13 @@ SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
- // The select routine does not have access to the LoadSDNode instance, so
- // pass along the extension information
+ // We currently are not lowering extending loads, but pass the extension
+ // type anyway as later handling expects it.
OtherOps.push_back(
DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
- SDValue NewLD = DAG.getMemIntrinsicNode(
- NVPTXISD::MLoadV1, DL, LD->getVTList(), OtherOps, LD->getMemoryVT(),
- LD->getMemOperand());
+ SDValue NewLD =
+ DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
+ LD->getMemoryVT(), LD->getMemOperand());
return NewLD;
}
return SDValue();
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 89bf0c290292a..3ede48eb0f7e3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -99,7 +99,7 @@ enum NodeType : unsigned {
LoadV2,
LoadV4,
LoadV8,
- MLoadV1,
+ MLoad,
LDUV2, // LDU.v2
LDUV4, // LDU.v4
StoreV2,
diff --git a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
index 6fa518e8d409b..ed5e943946fef 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
@@ -69,14 +69,12 @@ static bool tagInvariantLoads(Function &F) {
markLoadsAsInvariant(LI);
Changed = true;
}
- if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
- if (II->getIntrinsicID() == Intrinsic::masked_load) {
- if (isInvariantLoad(II, II->getOperand(0), IsKernelFn)) {
- markLoadsAsInvariant(II);
- Changed = true;
- }
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ if (II->getIntrinsicID() == Intrinsic::masked_load &&
+ isInvariantLoad(II, II->getOperand(0), IsKernelFn)) {
+ markLoadsAsInvariant(II);
+ Changed = true;
}
- }
}
}
return Changed;
>From 5f54ae686cd187f7dda9d07b8a3aae030ca9c63b Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Mon, 27 Oct 2025 21:22:14 +0000
Subject: [PATCH 11/11] Change pragma printing to use hex
---
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 2 +-
.../test/CodeGen/NVPTX/masked-load-vectors.ll | 34 +++++++++----------
2 files changed, 18 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 8ca3cb46b5455..6f747b70100b7 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -401,7 +401,7 @@ void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
assert(Op.isImm() && "Invalid operand");
uint32_t Imm = (uint32_t)Op.getImm();
if (Imm != UINT32_MAX) {
- O << ".pragma \"used_bytes_mask " << Imm << "\";\n\t";
+ O << ".pragma \"used_bytes_mask " << format_hex(Imm, 1) << "\";\n\t";
}
}
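[Editor's note] A small illustration of the decimal-versus-hex change this patch makes to the pragma text, assuming only llvm::format_hex from llvm/Support/Format.h and a raw_ostream; the test updates below show the same values, e.g. 3855 becoming 0xf0f:

  #include <cstdint>
  #include "llvm/Support/Format.h"
  #include "llvm/Support/raw_ostream.h"

  // Print a used-bytes mask both ways; format_hex with a minimum width of 1
  // emits the "0x" prefix and only as many digits as needed.
  void printMask(llvm::raw_ostream &O, uint32_t Imm) {
    O << "decimal: " << Imm << "\n";                      // e.g. 3855
    O << "hex:     " << llvm::format_hex(Imm, 1) << "\n"; // e.g. 0xf0f
  }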
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
index 7c7c51be9567d..8617dea310d6c 100644
--- a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -18,9 +18,9 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
-; SM90-NEXT: .pragma "used_bytes_mask 61440";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf000";
; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
-; SM90-NEXT: .pragma "used_bytes_mask 3855";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf0f";
; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
; SM90-NEXT: st.global.b32 [%rd2], %r5;
@@ -35,7 +35,7 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
-; SM100-NEXT: .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f";
; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
@@ -56,10 +56,10 @@ define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
; SM90-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0];
-; SM90-NEXT: .pragma "used_bytes_mask 61440";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf000";
; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r4;
-; SM90-NEXT: .pragma "used_bytes_mask 3855";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf0f";
; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r7;
; SM90-NEXT: mov.b32 {%rs5, %rs6}, %r5;
@@ -80,7 +80,7 @@ define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0];
-; SM100-NEXT: .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f";
; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
; SM100-NEXT: mov.b32 {%rs1, %rs2}, %r8;
; SM100-NEXT: mov.b32 {%rs3, %rs4}, %r3;
@@ -128,9 +128,9 @@ define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
-; SM90-NEXT: .pragma "used_bytes_mask 61440";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf000";
; SM90-NEXT: ld.global.nc.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
-; SM90-NEXT: .pragma "used_bytes_mask 3855";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf0f";
; SM90-NEXT: ld.global.nc.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
; SM90-NEXT: st.global.b32 [%rd2], %r5;
@@ -145,7 +145,7 @@ define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
-; SM100-NEXT: .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f";
; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
@@ -164,7 +164,7 @@ define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_param_0];
-; CHECK-NEXT: .pragma "used_bytes_mask 3";
+; CHECK-NEXT: .pragma "used_bytes_mask 0x3";
; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_param_1];
; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
@@ -184,7 +184,7 @@ define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_invariant_param_0];
-; CHECK-NEXT: .pragma "used_bytes_mask 3";
+; CHECK-NEXT: .pragma "used_bytes_mask 0x3";
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_invariant_param_1];
; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
@@ -220,7 +220,7 @@ define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_param_0];
-; CHECK-NEXT: .pragma "used_bytes_mask 5";
+; CHECK-NEXT: .pragma "used_bytes_mask 0x5";
; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_param_1];
; CHECK-NEXT: st.global.b8 [%rd2], %r1;
@@ -240,7 +240,7 @@ define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_invariant_param_0];
-; CHECK-NEXT: .pragma "used_bytes_mask 5";
+; CHECK-NEXT: .pragma "used_bytes_mask 0x5";
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_invariant_param_1];
; CHECK-NEXT: st.global.b8 [%rd2], %r1;
@@ -280,7 +280,7 @@ define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
-; SM90-NEXT: .pragma "used_bytes_mask 15";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf";
; SM90-NEXT: ld.global.v2.b32 {%r1, %r2}, [%rd1];
; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_param_1];
; SM90-NEXT: st.global.b32 [%rd2], %r1;
@@ -293,7 +293,7 @@ define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
-; SM100-NEXT: .pragma "used_bytes_mask 15";
+; SM100-NEXT: .pragma "used_bytes_mask 0xf";
; SM100-NEXT: ld.global.b64 %rd2, [%rd1];
; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_param_1];
; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
@@ -312,7 +312,7 @@ define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
-; SM90-NEXT: .pragma "used_bytes_mask 15";
+; SM90-NEXT: .pragma "used_bytes_mask 0xf";
; SM90-NEXT: ld.global.nc.v2.b32 {%r1, %r2}, [%rd1];
; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_invariant_param_1];
; SM90-NEXT: st.global.b32 [%rd2], %r1;
@@ -325,7 +325,7 @@ define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
-; SM100-NEXT: .pragma "used_bytes_mask 15";
+; SM100-NEXT: .pragma "used_bytes_mask 0xf";
; SM100-NEXT: ld.global.nc.b64 %rd2, [%rd1];
; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_invariant_param_1];
; SM100-NEXT: mov.b64 {%r1, _}, %rd2;