[llvm] [NVPTX] support packed f32 instructions for sm_100+ (PR #126337)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 22 15:24:02 PDT 2025


================
@@ -5051,26 +5080,395 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+/// OverrideVT - allows overriding result and memory type
+static std::optional<std::pair<SDValue, SDValue>>
+convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector,
+                  std::optional<EVT> OverrideVT = std::nullopt) {
+  EVT ResVT = N->getValueType(0);
+  if (OverrideVT)
+    ResVT = *OverrideVT;
+  SDLoc DL(N);
+
+  assert(ResVT.isVector() && "Vector load must have vector type");
+
+  auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
+  if (!NumEltsAndEltVT)
+    return std::nullopt;
+  auto [NumElts, EltVT] = NumEltsAndEltVT.value();
+
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+
+  Align Alignment = LD->getAlign();
+  auto &TD = DAG.getDataLayout();
+  Align PrefAlign = TD.getPrefTypeAlign(
+      OverrideVT.value_or(LD->getMemoryVT()).getTypeForEVT(*DAG.getContext()));
+  if (Alignment < PrefAlign) {
+    // This load is not sufficiently aligned, so bail out and let this vector
+    // load be scalarized.  Note that we may still be able to emit smaller
+    // vector loads.  For example, if we are loading a <4 x float> with an
+    // alignment of 8, this check will fail but the legalizer will try again
+    // with 2 x <2 x float>, which will succeed with an alignment of 8.
+    return std::nullopt;
+  }
+
+  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
+  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
+  // loaded type to i16 and propagate the "real" type as the memory type.
+  bool NeedTrunc = false;
+  if (EltVT.getSizeInBits() < 16) {
+    EltVT = MVT::i16;
+    NeedTrunc = true;
+  }
+
+  unsigned Opcode = 0;
+  SDVTList LdResVTs;
+
+  switch (NumElts) {
+  default:
+    return std::nullopt;
+  case 2:
+    Opcode = NVPTXISD::LoadV2;
+    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
+    break;
+  case 4: {
+    Opcode = NVPTXISD::LoadV4;
+    EVT ListVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
+    LdResVTs = DAG.getVTList(ListVTs);
+    break;
+  }
+  }
+
+  // Copy regular operands
+  SmallVector<SDValue, 8> OtherOps(N->ops());
+
+  // The select routine does not have access to the LoadSDNode instance, so
+  // pass along the extension information
+  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+
+  SDValue NewLD = DAG.getMemIntrinsicNode(
+      Opcode, DL, LdResVTs, OtherOps, OverrideVT.value_or(LD->getMemoryVT()),
+      LD->getMemOperand());
+
+  SDValue LoadChain = NewLD.getValue(NumElts);
+
+  if (BuildVector) {
+    SmallVector<SDValue> ScalarRes;
+    assert(NumElts <= ResVT.getVectorNumElements() &&
+           "NumElts should not increase, only decrease or stay the same.");
+    if (NumElts < ResVT.getVectorNumElements()) {
+      // If the number of elements has decreased, getVectorLoweringShape has
+      // upsized the element types
+      assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
+             EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
+      // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
+      // into individual elements.
+      for (unsigned i = 0; i < NumElts; ++i) {
+        SDValue SubVector = NewLD.getValue(i);
+        DAG.ExtractVectorElements(SubVector, ScalarRes);
+      }
+    } else {
+      for (unsigned i = 0; i < NumElts; ++i) {
+        SDValue Res = NewLD.getValue(i);
+        if (NeedTrunc)
+          Res =
+              DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
+        ScalarRes.push_back(Res);
+      }
+    }
+
+    SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes);
+    return {{BuildVec, LoadChain}};
+  }
+
+  return {{NewLD, LoadChain}};
+}
+
+static SDValue PerformLoadCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  auto *MemN = cast<MemSDNode>(N);
+  // only operate on vectors of f32s / i64s
----------------
AlexMaclean wrote:

Why is this?

https://github.com/llvm/llvm-project/pull/126337


More information about the llvm-commits mailing list