[llvm] [NVPTX] Generalize and extend upsizing when lowering 8/16-bit-element vector loads/stores (PR #119622)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 16 14:06:42 PST 2024
================
@@ -160,6 +163,76 @@ static bool Is16bitsType(MVT VT) {
VT.SimpleTy == MVT::i16);
}
+// When legalizing vector loads/stores, this function is called, which does two
+// things:
+// 1. Determines whether the vector is something we want to custom lower;
+// std::nullopt is returned if we do not want to custom lower it.
+// 2. If we do want to handle it, returns three parameters:
+// - unsigned int NumElts - The number of elements in the final vector
+// - EVT EltVT - The type of the elements in the final vector
+// - bool UpsizeElementTypes - Whether or not we are upsizing the elements of
+// the vector
+static std::optional<std::tuple<unsigned int, EVT, bool>>
+tryGetVectorLoweringParams(EVT ValVT) {
+ // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
+ // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
+ // vectorized loads/stores with the actual element type for i8/i16 as that
+ // would require v8/v16 variants that do not exist.
+ // In order to load/store such vectors efficiently, here in Type Legalization,
+ // we split the vector into word-sized chunks (v2i16/v4i8). Later, we will
+ // lower to PTX as vectors of b32.
+ bool UpsizeElementTypes = false;
+
+ if (!ValVT.isVector() || !ValVT.isSimple())
+ return std::nullopt;
+
+ EVT EltVT = ValVT.getVectorElementType();
+ unsigned NumElts = ValVT.getVectorNumElements();
+
+ // We only handle "native" vector sizes for now, e.g. <4 x double> is not
+ // legal. We can (and should) split that into 2 stores of <2 x double> here
+ // but I'm leaving that as a TODO for now.
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ return std::nullopt;
+ case MVT::v2i8:
+ case MVT::v2i16:
+ case MVT::v2i32:
+ case MVT::v2i64:
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ case MVT::v2f32:
+ case MVT::v2f64:
+ case MVT::v4i8:
+ case MVT::v4i16:
+ case MVT::v4i32:
+ case MVT::v4f16:
+ case MVT::v4bf16:
+ case MVT::v4f32:
+ // This is a "native" vector type
+ break;
+ case MVT::v8i8: // <2 x i8x4>
+ case MVT::v8f16: // <4 x f16x2>
+ case MVT::v8bf16: // <4 x bf16x2>
+ case MVT::v8i16: // <4 x i16x2>
+ case MVT::v16i8: // <4 x i8x4>
+ // This can be upsized into a "native" vector type
+ UpsizeElementTypes = true;
+ break;
+ }
+
+ if (UpsizeElementTypes) {
+ // Number of elements to pack in one word.
+ unsigned NPerWord = 32 / EltVT.getSizeInBits();
+ // Word-sized vector.
+ EltVT = MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord);
+ // Number of word-sized vectors.
+ NumElts = NumElts / NPerWord;
+ }
+
+ return std::tuple(NumElts, EltVT, UpsizeElementTypes);
----------------
Artem-B wrote:
And this becomes `llvm_unreachable()`
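The following is a minimal standalone sketch of the upsizing arithmetic in the hunk above (`NPerWord = 32 / EltBits`, `NumElts / NPerWord`). It assumes the reviewer's point is that every handled `switch` case returns its result directly, which would make the code after the `switch` dead and a natural place for `llvm_unreachable()`. The `VT` enum, the `Params` struct and `upsizeToWords()` are hypothetical simplifications covering integer types only. They are not LLVM's `MVT`/`EVT` API, and `std::abort()` stands in for `llvm_unreachable()`.

```cpp
#include <cassert>
#include <cstdlib>
#include <optional>

// Hypothetical stand-ins for a handful of LLVM MVT values (integers only).
enum class VT { v2i8, v4i8, v2i16, v4i16, v2i32, v4i32, v2i64,
                v4i64, v8i8, v8i16, v16i8 };

// Hypothetical result type. When Upsized is true, each of the NumElts
// elements is itself a 32-bit word holding LanesPerElt lanes of EltBits.
struct Params {
  unsigned NumElts;     // elements in the final vector
  unsigned EltBits;     // bit width of each original lane
  unsigned LanesPerElt; // lanes packed per final element (1 = not upsized)
  bool Upsized;
};

// Regroup NumElts lanes of EltBits bits into 32-bit words,
// e.g. v16i8 -> 4 x <4 x i8>, v8i16 -> 4 x <2 x i16>, v8i8 -> 2 x <4 x i8>.
Params upsizeToWords(unsigned EltBits, unsigned NumElts) {
  unsigned NPerWord = 32 / EltBits;  // lanes per 32-bit word
  return Params{NumElts / NPerWord, EltBits, NPerWord, true};
}

// Each case returns directly, so the statement after the switch can
// never execute; in LLVM that spot would hold llvm_unreachable().
std::optional<Params> tryGetVectorLoweringParams(VT V) {
  switch (V) {
  default:  // e.g. v4i64: not a "native" size, not handled here
    return std::nullopt;
  case VT::v2i8:  return Params{2, 8, 1, false};
  case VT::v2i16: return Params{2, 16, 1, false};
  case VT::v2i32: return Params{2, 32, 1, false};
  case VT::v2i64: return Params{2, 64, 1, false};
  case VT::v4i8:  return Params{4, 8, 1, false};
  case VT::v4i16: return Params{4, 16, 1, false};
  case VT::v4i32: return Params{4, 32, 1, false};
  case VT::v8i8:  return upsizeToWords(8, 8);   // <2 x i8x4>
  case VT::v8i16: return upsizeToWords(16, 8);  // <4 x i16x2>
  case VT::v16i8: return upsizeToWords(8, 16);  // <4 x i8x4>
  }
  std::abort();  // stand-in for llvm_unreachable()
}
```

For example, `v16i8` yields 4 elements of `<4 x i8>` (128 bits in total), which fits PTX's v4 form when each element is treated as a b32 word. Computing the result inside each case also removes the need for the separate `UpsizeElementTypes` flag and the post-switch fix-up.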
https://github.com/llvm/llvm-project/pull/119622