[llvm] [AArch64][GlobalISel] Basic SVE and fadd (PR #72976)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 21 03:20:49 PST 2023


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/72976

This appears to be the minimum needed to get SVE fadd working. It needs more testing; I am just putting it up to show that it works so far.

>From bb6fd12b5d8de8f8a56b5e215117fab5116a8a2b Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 21 Nov 2023 11:18:37 +0000
Subject: [PATCH] [AArch64][GlobalISel] Basic SVE and fadd

This appears to be the minimum needed to get SVE fadd working. It needs more
testing.
---
 .../CodeGen/GlobalISel/InstructionSelect.cpp  |  3 +-
 llvm/lib/CodeGen/RegisterBankInfo.cpp         |  6 +-
 .../AArch64/AArch64GenRegisterBankInfo.def    | 12 ++--
 .../Target/AArch64/AArch64ISelLowering.cpp    | 11 +++-
 .../Target/AArch64/AArch64RegisterBanks.td    |  2 +-
 .../AArch64/GISel/AArch64CallLowering.cpp     | 14 +++--
 .../GISel/AArch64InstructionSelector.cpp      | 27 +++++---
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    |  7 ++-
 .../AArch64/GISel/AArch64RegisterBankInfo.cpp | 62 ++++++++++---------
 .../AArch64/GISel/AArch64RegisterBankInfo.h   |  6 +-
 .../Target/AMDGPU/AMDGPURegisterBankInfo.cpp  |  6 +-
 llvm/test/CodeGen/AArch64/sve-add.ll          | 12 ++++
 llvm/utils/TableGen/InfoByHwMode.cpp          |  9 ++-
 13 files changed, 109 insertions(+), 68 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-add.ll

diff --git a/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp b/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
index baea773cf528e92..f04e4cdb764f2a3 100644
--- a/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
@@ -277,7 +277,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
     }
 
     const LLT Ty = MRI.getType(VReg);
-    if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
+    if (Ty.isValid() &&
+        TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
       reportGISelFailure(
           MF, TPC, MORE, "gisel-select",
           "VReg's low-level type and register class have different sizes", *MI);
diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index 6a96bb40f56aed9..5548430d1b0ae88 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -565,9 +565,9 @@ bool RegisterBankInfo::ValueMapping::verify(const RegisterBankInfo &RBI,
     OrigValueBitWidth =
         std::max(OrigValueBitWidth, PartMap.getHighBitIdx() + 1);
   }
-  assert(MeaningfulBitWidth.isScalable() ||
-         OrigValueBitWidth >= MeaningfulBitWidth &&
-             "Meaningful bits not covered by the mapping");
+  assert((MeaningfulBitWidth.isScalable() ||
+          OrigValueBitWidth >= MeaningfulBitWidth) &&
+         "Meaningful bits not covered by the mapping");
   APInt ValueMask(OrigValueBitWidth, 0);
   for (const RegisterBankInfo::PartialMapping &PartMap : *this) {
     // Check that the union of the partial mappings covers the whole value,
diff --git a/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def b/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
index b87421e5ee46ae5..0b3557e67240520 100644
--- a/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
+++ b/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
@@ -136,8 +136,8 @@ bool AArch64GenRegisterBankInfo::checkValueMapImpl(unsigned Idx,
                                                    unsigned Size,
                                                    unsigned Offset) {
   unsigned PartialMapBaseIdx = Idx - PartialMappingIdx::PMI_Min;
-  const ValueMapping &Map =
-      AArch64GenRegisterBankInfo::getValueMapping((PartialMappingIdx)FirstInBank, Size)[Offset];
+  const ValueMapping &Map = AArch64GenRegisterBankInfo::getValueMapping(
+      (PartialMappingIdx)FirstInBank, TypeSize::Fixed(Size))[Offset];
   return Map.BreakDown == &PartMappings[PartialMapBaseIdx] &&
          Map.NumBreakDowns == 1;
 }
@@ -167,7 +167,7 @@ bool AArch64GenRegisterBankInfo::checkPartialMappingIdx(
 }
 
 unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
-                                                             unsigned Size) {
+                                                             TypeSize Size) {
   if (RBIdx == PMI_FirstGPR) {
     if (Size <= 32)
       return 0;
@@ -178,6 +178,8 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
     return -1;
   }
   if (RBIdx == PMI_FirstFPR) {
+    if (Size.isScalable())
+      return 3;
     if (Size <= 16)
       return 0;
     if (Size <= 32)
@@ -197,7 +199,7 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
 
 const RegisterBankInfo::ValueMapping *
 AArch64GenRegisterBankInfo::getValueMapping(PartialMappingIdx RBIdx,
-                                            unsigned Size) {
+                                            TypeSize Size) {
   assert(RBIdx != PartialMappingIdx::PMI_None && "No mapping needed for that");
   unsigned BaseIdxOffset = getRegBankBaseIdxOffset(RBIdx, Size);
   if (BaseIdxOffset == -1u)
@@ -221,7 +223,7 @@ const AArch64GenRegisterBankInfo::PartialMappingIdx
 
 const RegisterBankInfo::ValueMapping *
 AArch64GenRegisterBankInfo::getCopyMapping(unsigned DstBankID,
-                                           unsigned SrcBankID, unsigned Size) {
+                                           unsigned SrcBankID, TypeSize Size) {
   assert(DstBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
   assert(SrcBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
   PartialMappingIdx DstRBIdx = BankIDToCopyMapIdx[DstBankID];
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d42ae4ff93a4442..2dc7ffbd7be4335 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -144,6 +144,11 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                  cl::desc("Maximum of xors"));
 
+cl::opt<bool> DisableSVEGISel(
+    "aarch64-disable-sve-gisel", cl::Hidden,
+    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+    cl::init(true));
+
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;
 
@@ -25277,15 +25282,15 @@ bool AArch64TargetLowering::shouldLocalize(
 }
 
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
+  if (DisableSVEGISel && Inst.getType()->isScalableTy())
     return true;
 
   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+    if (DisableSVEGISel && Inst.getOperand(i)->getType()->isScalableTy())
       return true;
 
   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (AI->getAllocatedType()->isScalableTy())
+    if (DisableSVEGISel && AI->getAllocatedType()->isScalableTy())
       return true;
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 615ce7d51d9ba74..9e2ed356299e2bc 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -13,7 +13,7 @@
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
 
 /// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
 
 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 84057ea8d2214ac..f8f321c5881b68e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -51,6 +51,8 @@
 
 using namespace llvm;
 
+extern cl::opt<bool> DisableSVEGISel;
+
 AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
   : CallLowering(&TLI) {}
 
@@ -387,8 +389,8 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
       // i1 is a special case because SDAG i1 true is naturally zero extended
       // when widened using ANYEXT. We need to do it explicitly here.
       auto &Flags = CurArgInfo.Flags[0];
-      if (MRI.getType(CurVReg).getSizeInBits() == 1 && !Flags.isSExt() &&
-          !Flags.isZExt()) {
+      if (MRI.getType(CurVReg).getSizeInBits() == TypeSize::Fixed(1) &&
+          !Flags.isSExt() && !Flags.isZExt()) {
         CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg).getReg(0);
       } else if (TLI.getNumRegistersForCallingConv(Ctx, CC, SplitEVTs[i]) ==
                  1) {
@@ -523,10 +525,10 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,
 
 bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
   auto &F = MF.getFunction();
-  if (F.getReturnType()->isScalableTy() ||
-      llvm::any_of(F.args(), [](const Argument &A) {
-        return A.getType()->isScalableTy();
-      }))
+  if (DisableSVEGISel && (F.getReturnType()->isScalableTy() ||
+                          llvm::any_of(F.args(), [](const Argument &A) {
+                            return A.getType()->isScalableTy();
+                          })))
     return true;
   const auto &ST = MF.getSubtarget<AArch64Subtarget>();
   if (!ST.hasNEON() || !ST.hasFPARMv8()) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index bdaae4dd724d536..9ad1e30c802dad9 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -595,11 +595,12 @@ getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB,
 /// Given a register bank, and size in bits, return the smallest register class
 /// that can represent that combination.
 static const TargetRegisterClass *
-getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
+getMinClassForRegBank(const RegisterBank &RB, TypeSize SizeInBits,
                       bool GetAllRegSet = false) {
   unsigned RegBankID = RB.getID();
 
   if (RegBankID == AArch64::GPRRegBankID) {
+    assert(!SizeInBits.isScalable() && "Unexpected scalable register size");
     if (SizeInBits <= 32)
       return GetAllRegSet ? &AArch64::GPR32allRegClass
                           : &AArch64::GPR32RegClass;
@@ -611,6 +612,12 @@ getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
   }
 
   if (RegBankID == AArch64::FPRRegBankID) {
+    if (SizeInBits.isScalable()) {
+      assert(SizeInBits == TypeSize::Scalable(128) &&
+             "Unexpected scalable register size");
+      return &AArch64::ZPRRegClass;
+    }
+
     switch (SizeInBits) {
     default:
       return nullptr;
@@ -937,8 +944,8 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   Register SrcReg = I.getOperand(1).getReg();
   const RegisterBank &DstRegBank = *RBI.getRegBank(DstReg, MRI, TRI);
   const RegisterBank &SrcRegBank = *RBI.getRegBank(SrcReg, MRI, TRI);
-  unsigned DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
-  unsigned SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
+  TypeSize DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
+  TypeSize SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
 
   // Special casing for cross-bank copies of s1s. We can technically represent
   // a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -948,8 +955,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   // then we can pull it into the helpers that get the appropriate class for a
   // register bank. Or make a new helper that carries along some constraint
   // information.
-  if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
-    SrcSize = DstSize = 32;
+  if (SrcRegBank != DstRegBank &&
+      (DstSize == TypeSize::Fixed(1) && SrcSize == TypeSize::Fixed(1)))
+    SrcSize = DstSize = TypeSize::Fixed(32);
 
   return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
           getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1014,10 +1022,15 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
       return false;
     }
 
-    unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
-    unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
+    TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
+    TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
     unsigned SubReg;
 
+    if (SrcSize.isScalable()) {
+      assert(DstSize.isScalable() && "Unhandled scalable copy");
+      return true;
+    }
+
     // If the source bank doesn't support a subregister copy small enough,
     // then we first need to copy to the destination bank.
     if (getMinSizeForRegBank(SrcRegBank) > DstSize) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 000fd648595222b..e55cf5400565215 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -59,6 +59,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v4s32 = LLT::fixed_vector(4, 32);
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
+  const LLT nxv16s8 = LLT::scalable_vector(16, 8);
+  const LLT nxv8s16 = LLT::scalable_vector(8, 16);
+  const LLT nxv4s32 = LLT::scalable_vector(4, 32);
+  const LLT nxv2s64 = LLT::scalable_vector(2, 64);
 
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
@@ -238,7 +242,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                                G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
                                G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
                                G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
+      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64, nxv16s8, nxv8s16,
+                 nxv4s32, nxv2s64})
       .legalIf([=](const LegalityQuery &Query) {
         const auto &Ty = Query.Types[0];
         return (Ty == v8s16 || Ty == v4s16) && HasFP16;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 4ca5b3674461d89..1466570cf317a7e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -161,17 +161,18 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
     unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min;               \
     (void)PartialMapDstIdx;                                                    \
     (void)PartialMapSrcIdx;                                                    \
-    const ValueMapping *Map = getCopyMapping(                                  \
-        AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size);  \
+    const ValueMapping *Map =                                                  \
+        getCopyMapping(AArch64::RBNameDst##RegBankID,                          \
+                       AArch64::RBNameSrc##RegBankID, TypeSize::Fixed(Size));  \
     (void)Map;                                                                 \
     assert(Map[0].BreakDown ==                                                 \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] &&  \
-           Map[0].NumBreakDowns == 1 && #RBNameDst #Size                       \
-           " Dst is incorrectly initialized");                                 \
+           Map[0].NumBreakDowns == 1 &&                                        \
+           #RBNameDst #Size " Dst is incorrectly initialized");                \
     assert(Map[1].BreakDown ==                                                 \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] &&  \
-           Map[1].NumBreakDowns == 1 && #RBNameSrc #Size                       \
-           " Src is incorrectly initialized");                                 \
+           Map[1].NumBreakDowns == 1 &&                                        \
+           #RBNameSrc #Size " Src is incorrectly initialized");                \
                                                                                \
   } while (false)
 
@@ -255,6 +256,9 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPR_3bRegClassID:
+  case AArch64::ZPR_4bRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -299,8 +303,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
   case TargetOpcode::G_OR: {
     // 32 and 64-bit or can be mapped on either FPR or
     // GPR for the same cost.
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(32) && Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -320,8 +324,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_BITCAST: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(32) && Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -341,15 +345,13 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     const InstructionMapping &GPRToFPRMapping = getInstructionMapping(
         /*ID*/ 3,
         /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::Fixed(Size)),
+        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::FPRRegBankID, AArch64::GPRRegBankID, Size),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRToGPRMapping = getInstructionMapping(
         /*ID*/ 3,
         /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::Fixed(Size)),
+        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::GPRRegBankID, AArch64::FPRRegBankID, Size),
         /*NumOperands*/ 2);
 
@@ -360,8 +362,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_LOAD: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -372,15 +374,17 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     InstructionMappings AltMappings;
     const InstructionMapping &GPRMapping = getInstructionMapping(
         /*ID*/ 1, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstGPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::Fixed(64))}),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRMapping = getInstructionMapping(
         /*ID*/ 2, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstFPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::Fixed(64))}),
         /*NumOperands*/ 2);
 
     AltMappings.push_back(&GPRMapping);
@@ -458,7 +462,7 @@ AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
          "This code is for instructions with 3 or less operands");
 
   LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-  unsigned Size = Ty.getSizeInBits();
+  TypeSize Size = Ty.getSizeInBits();
   bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);
 
   PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
@@ -711,9 +715,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       // If both RB are null that means both registers are generic.
       // We shouldn't be here.
       assert(DstRB && SrcRB && "Both RegBank were nullptr");
-      unsigned Size = getSizeInBits(DstReg, MRI, TRI);
+      TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
       return getInstructionMapping(
-          DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::Fixed(Size)),
+          DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
           getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
           // We only care about the mapping of the destination.
           /*NumOperands*/ 1);
@@ -724,7 +728,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_BITCAST: {
     LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
-    unsigned Size = DstTy.getSizeInBits();
+    TypeSize Size = DstTy.getSizeInBits();
     bool DstIsGPR = !DstTy.isVector() && DstTy.getSizeInBits() <= 64;
     bool SrcIsGPR = !SrcTy.isVector() && SrcTy.getSizeInBits() <= 64;
     const RegisterBank &DstRB =
@@ -732,7 +736,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     const RegisterBank &SrcRB =
         SrcIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
     return getInstructionMapping(
-        DefaultMappingID, copyCost(DstRB, SrcRB, TypeSize::Fixed(Size)),
+        DefaultMappingID, copyCost(DstRB, SrcRB, Size),
         getCopyMapping(DstRB.getID(), SrcRB.getID(), Size),
         // We only care about the mapping of the destination for COPY.
         /*NumOperands*/ Opc == TargetOpcode::G_BITCAST ? 2 : 1);
@@ -744,7 +748,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   unsigned NumOperands = MI.getNumOperands();
 
   // Track the size and bank of each register.  We don't do partial mappings.
-  SmallVector<unsigned, 4> OpSize(NumOperands);
+  SmallVector<TypeSize, 4> OpSize(NumOperands, TypeSize::Fixed(0));
   SmallVector<PartialMappingIdx, 4> OpRegBankIdx(NumOperands);
   for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
     auto &MO = MI.getOperand(Idx);
@@ -825,7 +829,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       Cost = copyCost(
           *AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[0]].RegBank,
           *AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[1]].RegBank,
-          TypeSize::Fixed(OpSize[0]));
+          OpSize[0]);
     break;
   case TargetOpcode::G_LOAD: {
     // Loading in vector unit is slightly more expensive.
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
index b6364c6a64099a4..bfbe2e5a06177e0 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
@@ -70,7 +70,7 @@ class AArch64GenRegisterBankInfo : public RegisterBankInfo {
                                      PartialMappingIdx LastAlias,
                                      ArrayRef<PartialMappingIdx> Order);
 
-  static unsigned getRegBankBaseIdxOffset(unsigned RBIdx, unsigned Size);
+  static unsigned getRegBankBaseIdxOffset(unsigned RBIdx, TypeSize Size);
 
   /// Get the pointer to the ValueMapping representing the RegisterBank
   /// at \p RBIdx with a size of \p Size.
@@ -80,13 +80,13 @@ class AArch64GenRegisterBankInfo : public RegisterBankInfo {
   ///
   /// \pre \p RBIdx != PartialMappingIdx::None
   static const RegisterBankInfo::ValueMapping *
-  getValueMapping(PartialMappingIdx RBIdx, unsigned Size);
+  getValueMapping(PartialMappingIdx RBIdx, TypeSize Size);
 
   /// Get the pointer to the ValueMapping of the operands of a copy
   /// instruction from the \p SrcBankID register bank to the \p DstBankID
   /// register bank with a size of \p Size.
   static const RegisterBankInfo::ValueMapping *
-  getCopyMapping(unsigned DstBankID, unsigned SrcBankID, unsigned Size);
+  getCopyMapping(unsigned DstBankID, unsigned SrcBankID, TypeSize Size);
 
   /// Get the instruction mapping for G_FPEXT.
   ///
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index 49322109bdb74f0..2c470766f129caf 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -243,10 +243,8 @@ unsigned AMDGPURegisterBankInfo::copyCost(const RegisterBank &Dst,
   // Legalization doesn't know about the necessary context, so an s1 use may
   // have been a truncate from an arbitrary value, in which case a copy (lowered
   // as a compare with 0) needs to be inserted.
-  if (Size == 1 &&
-      (Dst.getID() == AMDGPU::SGPRRegBankID) &&
-      (isVectorRegisterBank(Src) ||
-       Src.getID() == AMDGPU::SGPRRegBankID ||
+  if (Size == 1 && (Dst.getID() == AMDGPU::SGPRRegBankID) &&
+      (isVectorRegisterBank(Src) || Src.getID() == AMDGPU::SGPRRegBankID ||
        Src.getID() == AMDGPU::VCCRegBankID))
     return std::numeric_limits<unsigned>::max();
 
diff --git a/llvm/test/CodeGen/AArch64/sve-add.ll b/llvm/test/CodeGen/AArch64/sve-add.ll
new file mode 100644
index 000000000000000..2dc1aca7c3be511
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-add.ll
@@ -0,0 +1,12 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc < %s -mtriple aarch64 -global-isel -mattr=+sve -aarch64-disable-sve-gisel=false | FileCheck %s
+
+define <vscale x 4 x float> @add(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: add:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+entry:
+  %c = fadd <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %c
+}
diff --git a/llvm/utils/TableGen/InfoByHwMode.cpp b/llvm/utils/TableGen/InfoByHwMode.cpp
index 7e4ab5346621879..4a2f7bb0e29e68d 100644
--- a/llvm/utils/TableGen/InfoByHwMode.cpp
+++ b/llvm/utils/TableGen/InfoByHwMode.cpp
@@ -129,14 +129,13 @@ bool RegSizeInfo::operator< (const RegSizeInfo &I) const {
 }
 
 bool RegSizeInfo::isSubClassOf(const RegSizeInfo &I) const {
-  return RegSize <= I.RegSize &&
-         SpillAlignment && I.SpillAlignment % SpillAlignment == 0 &&
-         SpillSize <= I.SpillSize;
+  return RegSize <= I.RegSize && SpillAlignment &&
+         I.SpillAlignment % SpillAlignment == 0 && SpillSize <= I.SpillSize;
 }
 
 void RegSizeInfo::writeToStream(raw_ostream &OS) const {
-  OS << "[R=" << RegSize << ",S=" << SpillSize
-     << ",A=" << SpillAlignment << ']';
+  OS << "[R=" << RegSize << ",S=" << SpillSize << ",A=" << SpillAlignment
+     << ']';
 }
 
 RegSizeInfoByHwMode::RegSizeInfoByHwMode(Record *R,



More information about the llvm-commits mailing list