[llvm] [NVPTX] Generalize and extend upsizing when lowering 8/16-bit-element vector loads/stores (PR #119622)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 11 14:40:39 PST 2024


================
@@ -3223,32 +3262,30 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
     case 4:
       Opcode = NVPTXISD::StoreV4;
       break;
-    case 8:
-      // v8f16 is a special case. PTX doesn't have st.v8.f16
-      // instruction. Instead, we split the vector into v2f16 chunks and
-      // store them with st.v4.b32.
-      assert(Is16bitsType(EltVT.getSimpleVT()) && "Wrong type for the vector.");
-      Opcode = NVPTXISD::StoreV4;
-      StoreF16x2 = true;
-      break;
     }
 
     SmallVector<SDValue, 8> Ops;
 
     // First is the chain
     Ops.push_back(N->getOperand(0));
 
-    if (StoreF16x2) {
-      // Combine f16,f16 -> v2f16
-      NumElts /= 2;
+    if (UpsizeElementTypes) {
+      // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
+      // stored as b32s
+      unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
       for (unsigned i = 0; i < NumElts; ++i) {
-        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
-                                 DAG.getIntPtrConstant(i * 2, DL));
-        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
-                                 DAG.getIntPtrConstant(i * 2 + 1, DL));
-        EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
-        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
-        Ops.push_back(V2);
+        SmallVector<SDValue, 8> Elts;
+        for (unsigned j = 0; j < NumEltsPerSubVector; ++j) {
+          SDValue E = DAG.getNode(
+              ISD::EXTRACT_VECTOR_ELT, DL, EltVT.getVectorElementType(), Val,
+              DAG.getIntPtrConstant(i * NumEltsPerSubVector + j, DL));
+          Elts.push_back(E);
+        }
+        EVT VecVT =
+            EVT::getVectorVT(*DAG.getContext(), EltVT.getVectorElementType(),
+                             NumEltsPerSubVector);
+        SDValue SubVector = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Elts);
----------------
AlexMaclean wrote:

Can this `DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Elts)` call be replaced with [getBuildVector](https://llvm.org/doxygen/classllvm_1_1SelectionDAG.html#a947d455919dd654cced16d8ff4f8399c), i.e. `DAG.getBuildVector(VecVT, DL, Elts)`?

https://github.com/llvm/llvm-project/pull/119622


More information about the llvm-commits mailing list