[llvm] [LLVM][SVE] Honour calling convention when using SVE for fixed length vectors. (PR #70847)

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 27 15:52:25 PST 2023


================
@@ -26718,3 +26718,99 @@ bool AArch64TargetLowering::preferScalarizeSplat(SDNode *N) const {
 unsigned AArch64TargetLowering::getMinimumJumpTableEntries() const {
   // The threshold is owned by the subtarget; this hook just forwards it.
   const unsigned MinEntries = Subtarget->getMinimumJumpTableEntries();
   return MinEntries;
 }
+
+MVT AArch64TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
+                                                         CallingConv::ID CC,
+                                                         EVT VT) const {
+  // Special-case only fixed-length vectors with more than one element, and
+  // only when the subtarget uses SVE for fixed-length vectors. For those, the
+  // register type is whatever this target's
+  // getVectorTypeBreakdownForCallingConv selects; every other type takes the
+  // generic TargetLowering answer.
+  const bool UseCustomBreakdown = VT.isFixedLengthVector() &&
+                                  !VT.getVectorElementCount().isScalar() &&
+                                  Subtarget->useSVEForFixedLengthVectors();
+  if (UseCustomBreakdown) {
+    // Only the register type is needed; the other outputs are discarded.
+    EVT IntermediateVT;
+    MVT PartVT;
+    unsigned NumParts;
+    getVectorTypeBreakdownForCallingConv(Context, CC, VT, IntermediateVT,
+                                         NumParts, PartVT);
+    return PartVT;
+  }
+  return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
+}
+
+unsigned AArch64TargetLowering::getNumRegistersForCallingConv(
+    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
+  // For fixed-length vectors with more than one element, when the subtarget
+  // uses SVE for fixed-length vectors, the register count is the value
+  // returned by this target's getVectorTypeBreakdownForCallingConv. All other
+  // types defer to the generic TargetLowering implementation.
+  const bool IsMultiElementFixedVector =
+      VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
+  if (IsMultiElementFixedVector && Subtarget->useSVEForFixedLengthVectors()) {
+    // The breakdown's type outputs are not needed here, only its count.
+    EVT IntermediateVT;
+    MVT RegVT;
+    unsigned NumIntermediates;
+    return getVectorTypeBreakdownForCallingConv(
+        Context, CC, VT, IntermediateVT, NumIntermediates, RegVT);
+  }
+  return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
+}
+
+unsigned AArch64TargetLowering::getVectorTypeBreakdownForCallingConv(
+    LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
+    unsigned &NumIntermediates, MVT &RegisterVT) const {
+  int NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
+      Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
+  if (!RegisterVT.isFixedLengthVector() ||
+      RegisterVT.getFixedSizeInBits() <= 128)
+    return NumRegs;
+
+  assert(Subtarget->useSVEForFixedLengthVectors() && "Unexpected mode!");
+  assert(IntermediateVT == RegisterVT && "Unexpected VT mismatch!");
+  assert(RegisterVT.getFixedSizeInBits() % 128 == 0 && "Unexpected size!");
+
+  // A size mismatch here implies either type promotion or widening and would
+  // have resulted in scalarisation if larger vectors had not be available.
+  if (RegisterVT.getSizeInBits() * NumRegs != VT.getSizeInBits()) {
+    EVT EltTy = VT.getVectorElementType();
+    EVT NewVT = EVT::getVectorVT(Context, EltTy, ElementCount::getFixed(1));
+    if (!isTypeLegal(NewVT))
+      NewVT = EltTy;
+
+    IntermediateVT = NewVT;
+    NumIntermediates = VT.getVectorNumElements();
+    RegisterVT = getRegisterType(Context, NewVT);
+    return NumIntermediates;
+  }
+
+  // SVE VLS support does not introduce a new ABI so we should use NEON sized
+  // types for vector arguments and returns.
+
+  unsigned NumSubRegs = RegisterVT.getFixedSizeInBits() / 128;
+  NumIntermediates *= NumSubRegs;
+  NumRegs *= NumSubRegs;
+
+  switch (RegisterVT.getVectorElementType().SimpleTy) {
+  default:
+    llvm_unreachable("unexpected element type for vector");
+  case MVT::i8:
+    IntermediateVT = RegisterVT = MVT::v16i8;
+    break;
+  case MVT::i16:
+    IntermediateVT = RegisterVT = MVT::v8i16;
+    break;
+  case MVT::i32:
+    IntermediateVT = RegisterVT = MVT::v4i32;
+    break;
+  case MVT::i64:
+    IntermediateVT = RegisterVT = MVT::v2i64;
+    break;
+  case MVT::f16:
+    IntermediateVT = RegisterVT = MVT::v8i16;
----------------
davemgreen wrote:

->v8f16? (i.e. should the f16 case map to MVT::v8f16 rather than MVT::v8i16?)

https://github.com/llvm/llvm-project/pull/70847


More information about the llvm-commits mailing list