[llvm] [WIP] - [LLVM][SVE] Honour NEON calling convention when targeting SVE VLS. (PR #70847)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 31 11:17:39 PDT 2023
https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/70847
None
>From 681cb3550d5463e1a1853bcab3a0edc0fcc0f435 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Tue, 31 Oct 2023 17:44:31 +0000
Subject: [PATCH] [LLVM][SVE] Honour NEON calling convention when targeting SVE
VLS.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 75 +++++++++++++++++++
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 10 +++
2 files changed, 85 insertions(+)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d00db82c9e49ac2..b64e0560059274d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26542,3 +26542,78 @@ bool AArch64TargetLowering::preferScalarizeSplat(SDNode *N) const {
}
return true;
}
+
+/// Select the register type used to pass \p VT under calling convention
+/// \p CC. Fixed-length vectors wider than 128 bits are answered by this
+/// class's getVectorTypeBreakdownForCallingConv override (so the part type
+/// it chooses is used); every other type defers to TargetLowering.
+MVT AArch64TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
+                                                         CallingConv::ID CC,
+                                                         EVT VT) const {
+  const bool IsWideFixedVector =
+      VT.isFixedLengthVector() && VT.getFixedSizeInBits() > 128;
+  if (IsWideFixedVector) {
+    // Only the register type is wanted; the other outputs are discarded.
+    EVT IntermediateVT;
+    MVT PartVT;
+    unsigned NumIntermediates;
+    getVectorTypeBreakdownForCallingConv(Context, CC, VT, IntermediateVT,
+                                         NumIntermediates, PartVT);
+    return PartVT;
+  }
+  return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
+}
+
+/// Count the registers used to pass \p VT under calling convention \p CC.
+/// Fixed-length vectors wider than 128 bits take the count returned by this
+/// class's getVectorTypeBreakdownForCallingConv override; all other types use
+/// the generic TargetLowering answer.
+unsigned AArch64TargetLowering::getNumRegistersForCallingConv(
+    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
+  if (VT.isFixedLengthVector() && VT.getFixedSizeInBits() > 128) {
+    // The breakdown's type outputs are not needed here, only its count.
+    EVT IntermediateVT;
+    MVT RegisterVT;
+    unsigned NumIntermediates;
+    return getVectorTypeBreakdownForCallingConv(
+        Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
+  }
+  return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
+}
+
+/// Break \p VT down into the register parts used to pass it under calling
+/// convention \p CC.
+///
+/// SVE VLS support does not introduce a new ABI. When SVE is used for
+/// fixed-length vectors, the generic breakdown may pick a legal vector type
+/// wider than 128 bits. That result is rewritten to match what a NEON-only
+/// target would produce:
+///  * types the generic code promoted or widened into a wide SVE register
+///    are scalarised, as they would have been without wide vector types;
+///  * all other wide register types are split into 128-bit NEON vectors with
+///    the same element type.
+///
+/// \returns the number of registers needed. \p IntermediateVT,
+/// \p NumIntermediates and \p RegisterVT are updated to match.
+unsigned AArch64TargetLowering::getVectorTypeBreakdownForCallingConv(
+    LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
+    unsigned &NumIntermediates, MVT &RegisterVT) const {
+  unsigned NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
+      Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
+  if (!RegisterVT.isFixedLengthVector() ||
+      RegisterVT.getFixedSizeInBits() <= 128)
+    return NumRegs;
+
+  // Register types wider than NEON can only come from SVE VLS support.
+  assert(Subtarget->useSVEForFixedLengthVectors() && "Unexpected mode!");
+  assert(IntermediateVT == RegisterVT && "Unexpected VT mismatch!");
+  assert(RegisterVT.getFixedSizeInBits() % 128 == 0 && "Unexpected size!");
+
+  // A size mismatch means the generic code promoted or widened VT into a wide
+  // SVE register (e.g. v32i1 -> v32i8, v24i8 -> v32i8). Without such registers
+  // (i.e. NEON) the type would have been scalarised. Do the same here: one
+  // part per element, using a single-element vector when that is legal and
+  // the bare element type otherwise.
+  if (RegisterVT.getFixedSizeInBits() * NumRegs != VT.getFixedSizeInBits()) {
+    EVT EltVT = VT.getVectorElementType();
+    EVT NewVT = EVT::getVectorVT(Context, EltVT, ElementCount::getFixed(1));
+    if (!isTypeLegal(NewVT))
+      NewVT = EltVT;
+
+    IntermediateVT = NewVT;
+    NumIntermediates = VT.getVectorNumElements();
+    RegisterVT = getRegisterType(Context, NewVT);
+    return NumIntermediates;
+  }
+
+  // SVE VLS support does not introduce a new ABI, so force the use of NEON
+  // sized vector types for call arguments and returns.
+  unsigned NumSubRegs = RegisterVT.getFixedSizeInBits() / 128;
+
+  switch (RegisterVT.getVectorElementType().SimpleTy) {
+  default:
+    llvm_unreachable("unexpected element type for vector");
+  case MVT::i8:
+    IntermediateVT = RegisterVT = MVT::v16i8;
+    break;
+  case MVT::i16:
+    IntermediateVT = RegisterVT = MVT::v8i16;
+    break;
+  case MVT::i32:
+    IntermediateVT = RegisterVT = MVT::v4i32;
+    break;
+  case MVT::i64:
+    IntermediateVT = RegisterVT = MVT::v2i64;
+    break;
+  case MVT::f16:
+    // Keep half-precision elements in a floating-point vector type rather
+    // than reinterpreting them as integers.
+    IntermediateVT = RegisterVT = MVT::v8f16;
+    break;
+  case MVT::f32:
+    IntermediateVT = RegisterVT = MVT::v4f32;
+    break;
+  case MVT::f64:
+    IntermediateVT = RegisterVT = MVT::v2f64;
+    break;
+  case MVT::bf16:
+    IntermediateVT = RegisterVT = MVT::v8bf16;
+    break;
+  }
+
+  // Each wide register becomes NumSubRegs 128-bit NEON registers.
+  NumIntermediates *= NumSubRegs;
+  return NumRegs * NumSubRegs;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 7332a95615a4da5..2e53481c03e1616 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -946,6 +946,16 @@ class AArch64TargetLowering : public TargetLowering {
// used for 64bit and 128bit vectors as well.
bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
+
+  // Follow NEON ABI rules even when using SVE for fixed length vectors. SVE
+  // VLS support does not introduce a new ABI, so fixed-length vector
+  // arguments and returns wider than 128 bits must still be passed using
+  // register types no wider than a NEON (128-bit) register. Types of 128 bits
+  // or less, and non-vector types, keep the generic TargetLowering behaviour.
+  MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
+                                    EVT VT) const override;
+  unsigned getNumRegistersForCallingConv(LLVMContext &Context,
+                                         CallingConv::ID CC,
+                                         EVT VT) const override;
+  unsigned getVectorTypeBreakdownForCallingConv(
+      LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
+      unsigned &NumIntermediates, MVT &RegisterVT) const override;
+
private:
/// Keep a pointer to the AArch64Subtarget around so that we can
/// make the right decision when generating code for different targets.
More information about the llvm-commits
mailing list