[llvm] [NVPTX] Generalize and extend upsizing when lowering 8/16-bit-element vector loads/stores (PR #119622)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 14:06:42 PST 2024


================
@@ -160,6 +163,76 @@ static bool Is16bitsType(MVT VT) {
           VT.SimpleTy == MVT::i16);
 }
 
+// Called when legalizing vector loads/stores. This function does two things:
+// 1. Determines whether the vector is something we want to custom lower;
+//    std::nullopt is returned if we do not want to custom lower it.
+// 2. If we do want to handle it, returns three values:
+//    - unsigned int NumElts - The number of elements in the final vector
+//    - EVT EltVT - The type of the elements in the final vector
+//    - bool UpsizeElementTypes - Whether or not we are upsizing the elements of
+//      the vector
+static std::optional<std::tuple<unsigned int, EVT, bool>>
+tryGetVectorLoweringParams(EVT ValVT) {
+  // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
+  // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
+  // vectorized loads/stores with the actual element type for i8/i16 as that
+  // would require v8/v16 variants that do not exist.
+  // In order to load/store such vectors efficiently, here in Type Legalization,
+  // we split the vector into word-sized chunks (v2x16/v4i8). Later, we will
+  // lower to PTX as vectors of b32.
+  bool UpsizeElementTypes = false;
+
+  if (!ValVT.isVector() || !ValVT.isSimple())
+    return std::nullopt;
+
+  EVT EltVT = ValVT.getVectorElementType();
+  unsigned NumElts = ValVT.getVectorNumElements();
+
+  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
+  // legal.  We can (and should) split that into 2 stores of <2 x double> here
+  // but I'm leaving that as a TODO for now.
+  switch (ValVT.getSimpleVT().SimpleTy) {
+  default:
+    return std::nullopt;
+  case MVT::v2i8:
+  case MVT::v2i16:
+  case MVT::v2i32:
+  case MVT::v2i64:
+  case MVT::v2f16:
+  case MVT::v2bf16:
+  case MVT::v2f32:
+  case MVT::v2f64:
+  case MVT::v4i8:
+  case MVT::v4i16:
+  case MVT::v4i32:
+  case MVT::v4f16:
+  case MVT::v4bf16:
+  case MVT::v4f32:
+    // This is a "native" vector type
+    break;
+  case MVT::v8i8:   // <2 x i8x4>
+  case MVT::v8f16:  // <4 x f16x2>
+  case MVT::v8bf16: // <4 x bf16x2>
+  case MVT::v8i16:  // <4 x i16x2>
+  case MVT::v16i8:  // <4 x i8x4>
+    // This can be upsized into a "native" vector type
+    UpsizeElementTypes = true;
+    break;
+  }
+
+  if (UpsizeElementTypes) {
+    // Number of elements to pack in one word.
+    unsigned NPerWord = 32 / EltVT.getSizeInBits();
+    // Word-sized vector.
+    EltVT = MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord);
+    // Number of word-sized vectors.
+    NumElts = NumElts / NPerWord;
+  }
+
+  return std::tuple(NumElts, EltVT, UpsizeElementTypes);
----------------
Artem-B wrote:

And this becomes `llvm_unreachable()`
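For readers without an LLVM tree handy, the upsizing arithmetic in `tryGetVectorLoweringParams` can be sketched in plain C++. This is a minimal, hypothetical model, not the patch itself: `VecTy`, `LoweringParams` and `lowerVectorSketch` are illustrative names, and element types are modelled as a kind character plus a bit width instead of `MVT`/`EVT`. It accepts the same "native" and upsizable types as the switch above; for an upsized type, each 32-bit word holds `32 / EltBits` lanes, so a `v16i8` becomes 4 elements of `v4i8` and a `v8f16` becomes 4 elements of `v2f16`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for an LLVM simple vector type: element kind
// ('i' integer, 'f' IEEE float, 'b' bfloat), element width in bits, count.
struct VecTy {
  char Kind;
  unsigned EltBits;
  unsigned NumElts;
};

// Hypothetical analogue of the (NumElts, EltVT, UpsizeElementTypes) tuple.
// LanesPerElt is 1 for native types; when upsized, each element is itself a
// word-sized vector with LanesPerElt lanes of the original element type.
struct LoweringParams {
  unsigned NumElts;     // element count PTX sees (2 or 4)
  unsigned LanesPerElt; // 1, or 32 / EltBits when upsized
  bool Upsized;
};

// Mirrors the "native" cases of the switch: v2i8..v2i64, v4i8..v4i32,
// v2f16/v2f32/v2f64, v4f16/v4f32, v2bf16/v4bf16.
static bool isNative(const VecTy &T) {
  switch (T.Kind) {
  case 'i':
    return (T.NumElts == 2 && (T.EltBits == 8 || T.EltBits == 16 ||
                               T.EltBits == 32 || T.EltBits == 64)) ||
           (T.NumElts == 4 &&
            (T.EltBits == 8 || T.EltBits == 16 || T.EltBits == 32));
  case 'f':
    return (T.NumElts == 2 &&
            (T.EltBits == 16 || T.EltBits == 32 || T.EltBits == 64)) ||
           (T.NumElts == 4 && (T.EltBits == 16 || T.EltBits == 32));
  case 'b':
    return T.EltBits == 16 && (T.NumElts == 2 || T.NumElts == 4);
  }
  return false;
}

// Mirrors the upsizable cases: v8i8, v16i8, v8i16, v8f16, v8bf16.
static bool isUpsizable(const VecTy &T) {
  if (T.Kind == 'i' && T.EltBits == 8)
    return T.NumElts == 8 || T.NumElts == 16;
  return T.EltBits == 16 && T.NumElts == 8 &&
         (T.Kind == 'i' || T.Kind == 'f' || T.Kind == 'b');
}

std::optional<LoweringParams> lowerVectorSketch(const VecTy &T) {
  if (isNative(T))
    return LoweringParams{T.NumElts, 1, false};
  if (!isUpsizable(T))
    return std::nullopt; // e.g. v4f64: not handled here
  assert(32 % T.EltBits == 0 && "element must evenly divide a 32-bit word");
  unsigned NPerWord = 32 / T.EltBits;              // lanes packed per word
  return LoweringParams{T.NumElts / NPerWord, NPerWord, true};
}
```

With this model, `lowerVectorSketch({'i', 8, 16})` yields 4 elements of 4 lanes each (the `<4 x i8x4>` case in the switch comments), while `{'f', 64, 4}` (`v4f64`) yields `std::nullopt`, matching the TODO about splitting `<4 x double>`.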

https://github.com/llvm/llvm-project/pull/119622

