[llvm] [NVPTX] Lower LLVM masked vector loads and stores to PTX (PR #159387)
Drew Kersnar via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 21 09:22:17 PDT 2025
================
@@ -3098,10 +3189,58 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachinePointerInfo(SV));
}
+static std::tuple<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+ SDValue Chain = N->getOperand(0);
+ SDValue BasePtr = N->getOperand(1);
+ SDValue Mask = N->getOperand(3);
+ SDValue Passthru = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ResVT = N->getValueType(0);
+ assert(ResVT.isVector() && "Masked vector load must have vector type");
+ // While we only expect poison passthru vectors as an input to the backend,
+ // when the legalization framework splits a poison vector in half, it creates
+ // two undef vectors, so we can technically expect those too.
+ assert((Passthru.getOpcode() == ISD::POISON ||
+ Passthru.getOpcode() == ISD::UNDEF) &&
+ "Passthru operand expected to be poison or undef");
+
+ // Extract the mask and convert it to a uint32_t representing the used bytes
+ // of the entire vector load
+ uint32_t UsedBytesMask = 0;
+ uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+ assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+ uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+ uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+ for (unsigned I :
+ llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
+ assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+ "Mask elements must be constants");
+ // We technically only want to do this shift for every iteration *but* the
+ // first, but in the first iteration UsedBytesMask is 0, so this shift is a
+ // no-op.
+ UsedBytesMask <<= ElementSizeInBytes;
+
+ if (Mask->getConstantOperandVal(I) != 0)
+ UsedBytesMask |= ElementMask;
+ }
----------------
dakersnar wrote:
Actually, I think a comment like "// Mask elements must be constants" before the call to getZExtVal is sufficient. Also, we need to call llvm::reverse on Mask->ops(). Otherwise, looks good, thank you!
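
For readers following the thread, here is a minimal standalone sketch of the mask construction discussed above, written without any LLVM dependencies. The helper name computeUsedBytesMask and the std::vector<bool> lane representation are assumptions made for illustration only; the patch itself reads the lanes from the SelectionDAG mask operand. The reverse iteration stands in for the llvm::reverse traversal, so lane 0 ends up in the lowest bits of the result.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of the patch: builds a per-byte "used" mask
// from a per-lane boolean mask. Lane 0 maps to the least significant bits.
uint32_t computeUsedBytesMask(const std::vector<bool> &LaneMask,
                              unsigned ElementSizeInBytes) {
  // Assumption for this sketch: elements are 1 to 8 bytes wide and the whole
  // vector covers at most 32 bytes, so one bit per byte fits in a uint32_t.
  assert(ElementSizeInBytes >= 1 && ElementSizeInBytes <= 8 &&
         "Unexpected element size");
  assert(LaneMask.size() * ElementSizeInBytes <= 32 &&
         "Vector too wide for a 32-bit byte mask");

  const uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
  uint32_t UsedBytesMask = 0;

  // Walk the lanes from last to first. The shift on the first iteration is a
  // no-op because UsedBytesMask is still 0.
  for (auto It = LaneMask.rbegin(); It != LaneMask.rend(); ++It) {
    UsedBytesMask <<= ElementSizeInBytes;
    if (*It)
      UsedBytesMask |= ElementMask;
  }
  return UsedBytesMask;
}
```

Under these assumptions, a vector of four 2-byte elements with lanes {1, 0, 1, 1} yields 0xF3: bytes 0-1 and 4-7 are marked used, and bytes 2-3 are not.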
https://github.com/llvm/llvm-project/pull/159387