[llvm] [AArch64][GlobalISel] Basic SVE and fadd (PR #72976)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 5 00:20:50 PST 2024
https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/72976
From 9900a4dbfb39d7fa46592a9589a24acf51c514ff Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Fri, 5 Jan 2024 08:20:37 +0000
Subject: [PATCH] [AArch64][GlobalISel] Basic SVE and fadd
This appears to be the minimum needed to get SVE fadd working.
---
.../CodeGen/GlobalISel/InstructionSelect.cpp | 3 +-
.../AArch64/AArch64GenRegisterBankInfo.def | 12 ++--
.../Target/AArch64/AArch64ISelLowering.cpp | 11 +++-
.../Target/AArch64/AArch64RegisterBanks.td | 2 +-
.../AArch64/GISel/AArch64CallLowering.cpp | 14 ++--
.../GISel/AArch64InstructionSelector.cpp | 27 ++++++--
.../AArch64/GISel/AArch64LegalizerInfo.cpp | 6 +-
.../AArch64/GISel/AArch64RegisterBankInfo.cpp | 64 ++++++++++---------
.../AArch64/GISel/AArch64RegisterBankInfo.h | 6 +-
llvm/test/CodeGen/AArch64/sve-add.ll | 33 ++++++++++
10 files changed, 120 insertions(+), 58 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve-add.ll
diff --git a/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp b/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
index baea773cf528e9..f04e4cdb764f2a 100644
--- a/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
@@ -277,7 +277,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
}
const LLT Ty = MRI.getType(VReg);
- if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
+ if (Ty.isValid() &&
+ TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
reportGISelFailure(
MF, TPC, MORE, "gisel-select",
"VReg's low-level type and register class have different sizes", *MI);
diff --git a/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def b/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
index b87421e5ee46ae..f8a1387941fe8f 100644
--- a/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
+++ b/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
@@ -136,8 +136,8 @@ bool AArch64GenRegisterBankInfo::checkValueMapImpl(unsigned Idx,
unsigned Size,
unsigned Offset) {
unsigned PartialMapBaseIdx = Idx - PartialMappingIdx::PMI_Min;
- const ValueMapping &Map =
- AArch64GenRegisterBankInfo::getValueMapping((PartialMappingIdx)FirstInBank, Size)[Offset];
+ const ValueMapping &Map = AArch64GenRegisterBankInfo::getValueMapping(
+ (PartialMappingIdx)FirstInBank, TypeSize::getFixed(Size))[Offset];
return Map.BreakDown == &PartMappings[PartialMapBaseIdx] &&
Map.NumBreakDowns == 1;
}
@@ -167,7 +167,7 @@ bool AArch64GenRegisterBankInfo::checkPartialMappingIdx(
}
unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
- unsigned Size) {
+ TypeSize Size) {
if (RBIdx == PMI_FirstGPR) {
if (Size <= 32)
return 0;
@@ -178,6 +178,8 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
return -1;
}
if (RBIdx == PMI_FirstFPR) {
+ if (Size.isScalable())
+ return 3;
if (Size <= 16)
return 0;
if (Size <= 32)
@@ -197,7 +199,7 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
const RegisterBankInfo::ValueMapping *
AArch64GenRegisterBankInfo::getValueMapping(PartialMappingIdx RBIdx,
- unsigned Size) {
+ TypeSize Size) {
assert(RBIdx != PartialMappingIdx::PMI_None && "No mapping needed for that");
unsigned BaseIdxOffset = getRegBankBaseIdxOffset(RBIdx, Size);
if (BaseIdxOffset == -1u)
@@ -221,7 +223,7 @@ const AArch64GenRegisterBankInfo::PartialMappingIdx
const RegisterBankInfo::ValueMapping *
AArch64GenRegisterBankInfo::getCopyMapping(unsigned DstBankID,
- unsigned SrcBankID, unsigned Size) {
+ unsigned SrcBankID, TypeSize Size) {
assert(DstBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
assert(SrcBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
PartialMappingIdx DstRBIdx = BankIDToCopyMapIdx[DstBankID];
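
The placement of the new check matters: the "Size <= N" comparisons below it convert the TypeSize to uint64_t implicitly, and as far as I can tell from TypeSize.h that conversion rejects scalable quantities, so they must be filtered out first. A simplified standalone mirror of the patched FPR branch (not the LLVM code itself; the real table also has 128/256/512-bit rows):

  #include "llvm/Support/TypeSize.h"
  using namespace llvm;

  static unsigned fprBaseIdxOffset(TypeSize Size) {
    if (Size.isScalable())
      return 3;       // every scalable FPR value shares one ZPR-sized row
    if (Size <= 16)   // implicit TypeSize -> uint64_t; fixed sizes only
      return 0;
    if (Size <= 32)
      return 1;
    return 2;         // simplified: stands in for the 64..512-bit rows
  }

  int main() { return fprBaseIdxOffset(TypeSize::getScalable(128)) == 3 ? 0 : 1; }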
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 102fd0c3dae2ab..65493611e3af77 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -145,6 +145,11 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
cl::desc("Maximum of xors"));
+cl::opt<bool> DisableSVEGISel(
+ "aarch64-disable-sve-gisel", cl::Hidden,
+ cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+ cl::init(true));
+
/// Value type used for condition codes.
static const MVT MVT_CC = MVT::i32;
@@ -25423,15 +25428,15 @@ bool AArch64TargetLowering::shouldLocalize(
}
bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
- if (Inst.getType()->isScalableTy())
+ if (DisableSVEGISel && Inst.getType()->isScalableTy())
return true;
for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
- if (Inst.getOperand(i)->getType()->isScalableTy())
+ if (DisableSVEGISel && Inst.getOperand(i)->getType()->isScalableTy())
return true;
if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
- if (AI->getAllocatedType()->isScalableTy())
+ if (DisableSVEGISel && AI->getAllocatedType()->isScalableTy())
return true;
}
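
Note the default: the option is cl::init(true), so SVE types keep falling back to SelectionDAG and GlobalISel only handles them when a run opts in explicitly, as the new test's RUN lines below do:

  llc -mtriple aarch64 -mattr=+sve -global-isel -aarch64-disable-sve-gisel=false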
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 615ce7d51d9ba7..9e2ed356299e2b 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -13,7 +13,7 @@
def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
/// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
/// Conditional register: NZCV.
def CCRegBank : RegisterBank<"CC", [CCR]>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 84057ea8d2214a..77146503b079b9 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -51,6 +51,8 @@
using namespace llvm;
+extern cl::opt<bool> DisableSVEGISel;
+
AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
: CallLowering(&TLI) {}
@@ -387,8 +389,8 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
// i1 is a special case because SDAG i1 true is naturally zero extended
// when widened using ANYEXT. We need to do it explicitly here.
auto &Flags = CurArgInfo.Flags[0];
- if (MRI.getType(CurVReg).getSizeInBits() == 1 && !Flags.isSExt() &&
- !Flags.isZExt()) {
+ if (MRI.getType(CurVReg).getSizeInBits() == TypeSize::getFixed(1) &&
+ !Flags.isSExt() && !Flags.isZExt()) {
CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg).getReg(0);
} else if (TLI.getNumRegistersForCallingConv(Ctx, CC, SplitEVTs[i]) ==
1) {
@@ -523,10 +525,10 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,
bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
auto &F = MF.getFunction();
- if (F.getReturnType()->isScalableTy() ||
- llvm::any_of(F.args(), [](const Argument &A) {
- return A.getType()->isScalableTy();
- }))
+ if (DisableSVEGISel && (F.getReturnType()->isScalableTy() ||
+ llvm::any_of(F.args(), [](const Argument &A) {
+ return A.getType()->isScalableTy();
+ })))
return true;
const auto &ST = MF.getSubtarget<AArch64Subtarget>();
if (!ST.hasNEON() || !ST.hasFPARMv8()) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index a4ace6cce46342..9fa750685cca96 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -595,11 +595,12 @@ getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB,
/// Given a register bank, and size in bits, return the smallest register class
/// that can represent that combination.
static const TargetRegisterClass *
-getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
+getMinClassForRegBank(const RegisterBank &RB, TypeSize SizeInBits,
bool GetAllRegSet = false) {
unsigned RegBankID = RB.getID();
if (RegBankID == AArch64::GPRRegBankID) {
+ assert(!SizeInBits.isScalable() && "Unexpected scalable register size");
if (SizeInBits <= 32)
return GetAllRegSet ? &AArch64::GPR32allRegClass
: &AArch64::GPR32RegClass;
@@ -611,6 +612,12 @@ getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
}
if (RegBankID == AArch64::FPRRegBankID) {
+ if (SizeInBits.isScalable()) {
+ assert(SizeInBits == TypeSize::getScalable(128) &&
+ "Unexpected scalable register size");
+ return &AArch64::ZPRRegClass;
+ }
+
switch (SizeInBits) {
default:
return nullptr;
@@ -937,8 +944,8 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
Register SrcReg = I.getOperand(1).getReg();
const RegisterBank &DstRegBank = *RBI.getRegBank(DstReg, MRI, TRI);
const RegisterBank &SrcRegBank = *RBI.getRegBank(SrcReg, MRI, TRI);
- unsigned DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
- unsigned SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
+ TypeSize DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
+ TypeSize SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
// Special casing for cross-bank copies of s1s. We can technically represent
// a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -948,8 +955,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
// then we can pull it into the helpers that get the appropriate class for a
// register bank. Or make a new helper that carries along some constraint
// information.
- if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
- SrcSize = DstSize = 32;
+ if (SrcRegBank != DstRegBank &&
+ (DstSize == TypeSize::getFixed(1) && SrcSize == TypeSize::getFixed(1)))
+ SrcSize = DstSize = TypeSize::getFixed(32);
return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1014,10 +1022,15 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
return false;
}
- unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
- unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
+ TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
+ TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
unsigned SubReg;
+ if (SrcSize.isScalable()) {
+ assert(DstSize.isScalable() && "Unhandled scalable copy");
+ return true;
+ }
+
// If the source bank doesn't support a subregister copy small enough,
// then we first need to copy to the destination bank.
if (getMinSizeForRegBank(SrcRegBank) > DstSize) {
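
For scalable values the new early-out accepts the copy as-is: a Z-register-to-Z-register COPY appears to need none of the subregister fix-ups below, which only make sense for fixed sizes anyway. The one scalable size expected here is a full Z register; a small sketch of that quantity (again assuming an LLVM tree on the include path):

  #include "llvm/Support/TypeSize.h"
  #include <cassert>
  using namespace llvm;

  int main() {
    // A Z register is vscale x 128 bits: 128 bits per unit of vscale,
    // distinct from (and never equal to) a fixed 128 bits.
    TypeSize ZRegSize = TypeSize::getScalable(128);
    assert(ZRegSize.isScalable());
    assert(ZRegSize.getKnownMinValue() == 128);
    assert(ZRegSize != TypeSize::getFixed(128));
    return 0;
  }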
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 470742cdc30e6f..1f9a8169e83f73 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -59,6 +59,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
const LLT v4s32 = LLT::fixed_vector(4, 32);
const LLT v2s64 = LLT::fixed_vector(2, 64);
const LLT v2p0 = LLT::fixed_vector(2, p0);
+ const LLT nxv8s16 = LLT::scalable_vector(8, 16);
+ const LLT nxv4s32 = LLT::scalable_vector(4, 32);
+ const LLT nxv2s64 = LLT::scalable_vector(2, 64);
std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
v16s8, v8s16, v4s32,
@@ -238,7 +241,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
- .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
+ .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64, nxv8s16, nxv4s32,
+ nxv2s64})
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[0];
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
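
The three new LLTs are the fully-packed SVE arrangements, each filling one Z register per unit of vscale. A standalone sketch of what such a type carries (header path as in recent trees; adjust for older ones):

  #include "llvm/CodeGenTypes/LowLevelType.h"
  #include <cassert>
  using namespace llvm;

  int main() {
    // <vscale x 2 x s64>: a minimum of 2 x 64 = 128 bits, scaled by vscale.
    LLT nxv2s64 = LLT::scalable_vector(2, 64);
    assert(nxv2s64.isVector() && nxv2s64.isScalable());
    assert(nxv2s64.getElementType() == LLT::scalar(64));
    assert(nxv2s64.getSizeInBits() == TypeSize::getScalable(128));
    return 0;
  }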
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index b8e5e7bbdaba77..0b99d6babae7b6 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -162,17 +162,18 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min; \
(void)PartialMapDstIdx; \
(void)PartialMapSrcIdx; \
- const ValueMapping *Map = getCopyMapping( \
- AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size); \
+ const ValueMapping *Map = getCopyMapping(AArch64::RBNameDst##RegBankID, \
+ AArch64::RBNameSrc##RegBankID, \
+ TypeSize::getFixed(Size)); \
(void)Map; \
assert(Map[0].BreakDown == \
&AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] && \
- Map[0].NumBreakDowns == 1 && #RBNameDst #Size \
- " Dst is incorrectly initialized"); \
+ Map[0].NumBreakDowns == 1 && \
+ #RBNameDst #Size " Dst is incorrectly initialized"); \
assert(Map[1].BreakDown == \
&AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] && \
- Map[1].NumBreakDowns == 1 && #RBNameSrc #Size \
- " Src is incorrectly initialized"); \
+ Map[1].NumBreakDowns == 1 && \
+ #RBNameSrc #Size " Src is incorrectly initialized"); \
\
} while (false)
@@ -256,6 +257,9 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
case AArch64::QQRegClassID:
case AArch64::QQQRegClassID:
case AArch64::QQQQRegClassID:
+ case AArch64::ZPR_3bRegClassID:
+ case AArch64::ZPR_4bRegClassID:
+ case AArch64::ZPRRegClassID:
return getRegBank(AArch64::FPRRegBankID);
case AArch64::GPR32commonRegClassID:
case AArch64::GPR32RegClassID:
@@ -300,8 +304,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
case TargetOpcode::G_OR: {
// 32 and 64-bit or can be mapped on either FPR or
// GPR for the same cost.
- unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
- if (Size != 32 && Size != 64)
+ TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+ if (Size != TypeSize::getFixed(32) && Size != TypeSize::getFixed(64))
break;
// If the instruction has any implicit-defs or uses,
@@ -321,8 +325,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
return AltMappings;
}
case TargetOpcode::G_BITCAST: {
- unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
- if (Size != 32 && Size != 64)
+ TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+ if (Size != TypeSize::getFixed(32) && Size != TypeSize::getFixed(64))
break;
// If the instruction has any implicit-defs or uses,
@@ -341,16 +345,12 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
/*NumOperands*/ 2);
const InstructionMapping &GPRToFPRMapping = getInstructionMapping(
/*ID*/ 3,
- /*Cost*/
- copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
- TypeSize::getFixed(Size)),
+ /*Cost*/ copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
getCopyMapping(AArch64::FPRRegBankID, AArch64::GPRRegBankID, Size),
/*NumOperands*/ 2);
const InstructionMapping &FPRToGPRMapping = getInstructionMapping(
/*ID*/ 3,
- /*Cost*/
- copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
- TypeSize::getFixed(Size)),
+ /*Cost*/ copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
getCopyMapping(AArch64::GPRRegBankID, AArch64::FPRRegBankID, Size),
/*NumOperands*/ 2);
@@ -361,8 +361,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
return AltMappings;
}
case TargetOpcode::G_LOAD: {
- unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
- if (Size != 64)
+ TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+ if (Size != TypeSize::getFixed(64))
break;
// If the instruction has any implicit-defs or uses,
@@ -373,15 +373,17 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
InstructionMappings AltMappings;
const InstructionMapping &GPRMapping = getInstructionMapping(
/*ID*/ 1, /*Cost*/ 1,
- getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
- // Addresses are GPR 64-bit.
- getValueMapping(PMI_FirstGPR, 64)}),
+ getOperandsMapping(
+ {getValueMapping(PMI_FirstGPR, Size),
+ // Addresses are GPR 64-bit.
+ getValueMapping(PMI_FirstGPR, TypeSize::getFixed(64))}),
/*NumOperands*/ 2);
const InstructionMapping &FPRMapping = getInstructionMapping(
/*ID*/ 2, /*Cost*/ 1,
- getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
- // Addresses are GPR 64-bit.
- getValueMapping(PMI_FirstGPR, 64)}),
+ getOperandsMapping(
+ {getValueMapping(PMI_FirstFPR, Size),
+ // Addresses are GPR 64-bit.
+ getValueMapping(PMI_FirstGPR, TypeSize::getFixed(64))}),
/*NumOperands*/ 2);
AltMappings.push_back(&GPRMapping);
@@ -459,7 +461,7 @@ AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
"This code is for instructions with 3 or less operands");
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
- unsigned Size = Ty.getSizeInBits();
+ TypeSize Size = Ty.getSizeInBits();
bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);
PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
@@ -719,9 +721,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
// If both RB are null that means both registers are generic.
// We shouldn't be here.
assert(DstRB && SrcRB && "Both RegBank were nullptr");
- unsigned Size = getSizeInBits(DstReg, MRI, TRI);
+ TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
return getInstructionMapping(
- DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::getFixed(Size)),
+ DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
// We only care about the mapping of the destination.
/*NumOperands*/ 1);
@@ -732,7 +734,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case TargetOpcode::G_BITCAST: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
- unsigned Size = DstTy.getSizeInBits();
+ TypeSize Size = DstTy.getSizeInBits();
bool DstIsGPR = !DstTy.isVector() && DstTy.getSizeInBits() <= 64;
bool SrcIsGPR = !SrcTy.isVector() && SrcTy.getSizeInBits() <= 64;
const RegisterBank &DstRB =
@@ -740,7 +742,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
const RegisterBank &SrcRB =
SrcIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
return getInstructionMapping(
- DefaultMappingID, copyCost(DstRB, SrcRB, TypeSize::getFixed(Size)),
+ DefaultMappingID, copyCost(DstRB, SrcRB, Size),
getCopyMapping(DstRB.getID(), SrcRB.getID(), Size),
// We only care about the mapping of the destination for COPY.
/*NumOperands*/ Opc == TargetOpcode::G_BITCAST ? 2 : 1);
@@ -752,7 +754,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
unsigned NumOperands = MI.getNumOperands();
// Track the size and bank of each register. We don't do partial mappings.
- SmallVector<unsigned, 4> OpSize(NumOperands);
+ SmallVector<TypeSize, 4> OpSize(NumOperands, TypeSize::getFixed(0));
SmallVector<PartialMappingIdx, 4> OpRegBankIdx(NumOperands);
for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
auto &MO = MI.getOperand(Idx);
@@ -833,7 +835,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
Cost = copyCost(
*AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[0]].RegBank,
*AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[1]].RegBank,
- TypeSize::getFixed(OpSize[0]));
+ OpSize[0]);
break;
case TargetOpcode::G_LOAD: {
// Loading in vector unit is slightly more expensive.
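
One subtlety in the OpSize change: TypeSize has no default constructor (every value is explicitly fixed or scalable), so the SmallVector needs an explicit fill value where SmallVector<unsigned, 4> used to zero-initialize on its own. A minimal sketch:

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Support/TypeSize.h"
  using namespace llvm;

  int main() {
    // SmallVector<TypeSize, 4> Sizes(4);  // ill-formed: no default ctor
    SmallVector<TypeSize, 4> Sizes(4, TypeSize::getFixed(0)); // OK
    Sizes[0] = TypeSize::getScalable(128); // entries may later be scalable
    return 0;
  }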
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
index b6364c6a64099a..bfbe2e5a06177e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
@@ -70,7 +70,7 @@ class AArch64GenRegisterBankInfo : public RegisterBankInfo {
PartialMappingIdx LastAlias,
ArrayRef<PartialMappingIdx> Order);
- static unsigned getRegBankBaseIdxOffset(unsigned RBIdx, unsigned Size);
+ static unsigned getRegBankBaseIdxOffset(unsigned RBIdx, TypeSize Size);
/// Get the pointer to the ValueMapping representing the RegisterBank
/// at \p RBIdx with a size of \p Size.
@@ -80,13 +80,13 @@ class AArch64GenRegisterBankInfo : public RegisterBankInfo {
///
/// \pre \p RBIdx != PartialMappingIdx::None
static const RegisterBankInfo::ValueMapping *
- getValueMapping(PartialMappingIdx RBIdx, unsigned Size);
+ getValueMapping(PartialMappingIdx RBIdx, TypeSize Size);
/// Get the pointer to the ValueMapping of the operands of a copy
/// instruction from the \p SrcBankID register bank to the \p DstBankID
/// register bank with a size of \p Size.
static const RegisterBankInfo::ValueMapping *
- getCopyMapping(unsigned DstBankID, unsigned SrcBankID, unsigned Size);
+ getCopyMapping(unsigned DstBankID, unsigned SrcBankID, TypeSize Size);
/// Get the instruction mapping for G_FPEXT.
///
diff --git a/llvm/test/CodeGen/AArch64/sve-add.ll b/llvm/test/CodeGen/AArch64/sve-add.ll
new file mode 100644
index 00000000000000..828f1e7342d86d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-add.ll
@@ -0,0 +1,33 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc < %s -mtriple aarch64 -mattr=+sve -aarch64-disable-sve-gisel=false | FileCheck %s
+; RUN: llc < %s -mtriple aarch64 -mattr=+sve -global-isel -aarch64-disable-sve-gisel=false | FileCheck %s
+
+define <vscale x 2 x double> @addnxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: addnxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd z0.d, z0.d, z1.d
+; CHECK-NEXT: ret
+entry:
+ %c = fadd <vscale x 2 x double> %a, %b
+ ret <vscale x 2 x double> %c
+}
+
+define <vscale x 4 x float> @addnxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: addnxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+entry:
+ %c = fadd <vscale x 4 x float> %a, %b
+ ret <vscale x 4 x float> %c
+}
+
+define <vscale x 8 x half> @addnxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: addnxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd z0.h, z0.h, z1.h
+; CHECK-NEXT: ret
+entry:
+ %c = fadd <vscale x 8 x half> %a, %b
+ ret <vscale x 8 x half> %c
+}