[llvm] [NVPTX] Lower LLVM masked vector loads and stores to PTX (PR #159387)
    Drew Kersnar via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Tue Oct 21 09:22:17 PDT 2025
    
    
  
================
@@ -3098,10 +3189,58 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+static std::tuple<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+  SDValue Chain = N->getOperand(0);
+  SDValue BasePtr = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue Passthru = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // when the legalization framework splits a poison vector in half, it creates
+  // two undef vectors, so we can technically expect those too.
+  assert((Passthru.getOpcode() == ISD::POISON ||
+          Passthru.getOpcode() == ISD::UNDEF) &&
+         "Passthru operand expected to be poison or undef");
+
+  // Extract the mask and convert it to a uint32_t representing the used bytes
+  // of the entire vector load
+  uint32_t UsedBytesMask = 0;
+  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+  for (unsigned I :
+       llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    // We technically only want to do this shift for every iteration *but* the
+    // first, but in the first iteration UsedBytesMask is 0, so this shift is a
+    // no-op.
+    UsedBytesMask <<= ElementSizeInBytes;
+
+    if (Mask->getConstantOperandVal(I) != 0)
+      UsedBytesMask |= ElementMask;
+  }
----------------
dakersnar wrote:
Actually, I think a comment like "// Mask elements must be constants" before the call to getZExtValue is sufficient. Also, we need to call llvm::reverse on Mask->ops() so that element 0 still ends up in the lowest bits of the used-bytes mask. Otherwise, looks good, thank you!
https://github.com/llvm/llvm-project/pull/159387
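
For readers following the thread, here is a minimal standalone sketch of the byte-mask construction in the loop above. It is not the code from the patch: `computeUsedBytesMask` is a hypothetical helper, and a plain `std::vector<bool>` stands in for the constant mask operands of the SelectionDAG node. It assumes the whole vector is at most 32 bytes wide and each element is smaller than 32 bytes, so the shifts stay in range.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the patch): expands a per-element boolean
// mask into a per-byte mask, with element 0 occupying the lowest bits.
// Assumes Mask.size() * ElementSizeInBytes <= 32 and ElementSizeInBytes < 32.
uint32_t computeUsedBytesMask(const std::vector<bool> &Mask,
                              uint32_t ElementSizeInBytes) {
  // One set bit per byte of a single element, e.g. 0xF for 4-byte elements.
  const uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
  uint32_t UsedBytesMask = 0;

  // Walk the elements from last to first so that, after all the shifts,
  // element 0 lands in the least significant bits.
  for (auto It = Mask.rbegin(); It != Mask.rend(); ++It) {
    // On the first iteration the mask is still 0, so this shift is a no-op.
    UsedBytesMask <<= ElementSizeInBytes;
    if (*It)
      UsedBytesMask |= ElementMask;
  }
  return UsedBytesMask;
}
```

With 4-byte elements and a mask of {1, 0, 1, 1}, this yields 0xFF0F: the bytes of element 1 are cleared and all others are set. Iterating in forward order instead would mirror the result, which is why the reverse iteration matters.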
    
    