[llvm] aa7eace - [TableGen][GlobalISel] Account for HwMode in RegisterBank register sizes

Nitin John Raj via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 2 23:18:45 PDT 2023


Author: Nitin John Raj
Date: 2023-06-02T23:14:17-07:00
New Revision: aa7eace8431ba213c5431638b894b0e1b4b481c7

URL: https://github.com/llvm/llvm-project/commit/aa7eace8431ba213c5431638b894b0e1b4b481c7
DIFF: https://github.com/llvm/llvm-project/commit/aa7eace8431ba213c5431638b894b0e1b4b481c7.diff

LOG: [TableGen][GlobalISel] Account for HwMode in RegisterBank register sizes

This patch moves the logic for determining a RegisterBank's size from RegisterBank into RegisterBankInfo, which allows the HwMode of the target to be taken into account. Individual RegisterBanks cannot be constructed with HwMode information because their construction is generated by TableGen, but a RegisterBankInfo subclass can provide the HwMode as a constructor argument. The HwMode is used to select the appropriate RegisterBank size from a TableGen-generated array of sizes indexed by HwMode and RegisterBank ID.

Targets only need to pass the HwMode argument to the <target>GenRegisterBankInfo constructor (it defaults to 0). The RISC-V RegisterBankInfo constructor has been updated accordingly, and its unused TargetRegisterInfo argument has been removed.

Reviewed By: simoncook, craig.topper

Differential Revision: https://reviews.llvm.org/D76007

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/RegisterBank.h
    llvm/include/llvm/CodeGen/RegisterBankInfo.h
    llvm/lib/CodeGen/MachineVerifier.cpp
    llvm/lib/CodeGen/RegisterBank.cpp
    llvm/lib/CodeGen/RegisterBankInfo.cpp
    llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
    llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
    llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
    llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
    llvm/lib/Target/RISCV/RISCVSubtarget.cpp
    llvm/lib/Target/X86/X86RegisterBankInfo.cpp
    llvm/utils/TableGen/RegisterBankEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/RegisterBank.h b/llvm/include/llvm/CodeGen/RegisterBank.h
index 66885f113e8ee..ee295c7cdde00 100644
--- a/llvm/include/llvm/CodeGen/RegisterBank.h
+++ b/llvm/include/llvm/CodeGen/RegisterBank.h
@@ -29,7 +29,6 @@ class RegisterBank {
 private:
   unsigned ID;
   const char *Name;
-  unsigned Size;
   BitVector ContainedRegClasses;
 
   /// Sentinel value used to recognize register bank not properly
@@ -40,8 +39,8 @@ class RegisterBank {
   friend RegisterBankInfo;
 
 public:
-  RegisterBank(unsigned ID, const char *Name, unsigned Size,
-               const uint32_t *CoveredClasses, unsigned NumRegClasses);
+  RegisterBank(unsigned ID, const char *Name, const uint32_t *CoveredClasses,
+               unsigned NumRegClasses);
 
   /// Get the identifier of this register bank.
   unsigned getID() const { return ID; }
@@ -50,9 +49,6 @@ class RegisterBank {
   /// Should be used only for debugging purposes.
   const char *getName() const { return Name; }
 
-  /// Get the maximal size in bits that fits in this register bank.
-  unsigned getSize() const { return Size; }
-
   /// Check whether this instance is ready to be used.
   bool isValid() const;
 
@@ -62,7 +58,7 @@ class RegisterBank {
   /// \note This method does not check anything when assertions are disabled.
   ///
   /// \return True is the check was successful.
-  bool verify(const TargetRegisterInfo &TRI) const;
+  bool verify(const RegisterBankInfo &RBI, const TargetRegisterInfo &TRI) const;
 
   /// Check whether this register bank covers \p RC.
   /// In other words, check if this register bank fully covers

diff  --git a/llvm/include/llvm/CodeGen/RegisterBankInfo.h b/llvm/include/llvm/CodeGen/RegisterBankInfo.h
index f0aaf378bb8df..60f03756e1b54 100644
--- a/llvm/include/llvm/CodeGen/RegisterBankInfo.h
+++ b/llvm/include/llvm/CodeGen/RegisterBankInfo.h
@@ -20,6 +20,7 @@
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/CodeGen/LowLevelType.h"
 #include "llvm/CodeGen/Register.h"
+#include "llvm/CodeGen/RegisterBank.h"
 #include "llvm/Support/ErrorHandling.h"
 #include <cassert>
 #include <initializer_list>
@@ -30,7 +31,6 @@ namespace llvm {
 class MachineInstr;
 class MachineRegisterInfo;
 class raw_ostream;
-class RegisterBank;
 class TargetInstrInfo;
 class TargetRegisterClass;
 class TargetRegisterInfo;
@@ -83,7 +83,7 @@ class RegisterBankInfo {
     /// \note This method does not check anything when assertions are disabled.
     ///
     /// \return True is the check was successful.
-    bool verify() const;
+    bool verify(const RegisterBankInfo &RBI) const;
   };
 
   /// Helper struct that represents how a value is mapped through
@@ -175,7 +175,7 @@ class RegisterBankInfo {
     /// \note This method does not check anything when assertions are disabled.
     ///
     /// \return True is the check was successful.
-    bool verify(unsigned MeaningfulBitWidth) const;
+    bool verify(const RegisterBankInfo &RBI, unsigned MeaningfulBitWidth) const;
 
     /// Print this on dbgs() stream.
     void dump() const;
@@ -384,11 +384,17 @@ class RegisterBankInfo {
 
 protected:
   /// Hold the set of supported register banks.
-  RegisterBank **RegBanks;
+  const RegisterBank **RegBanks;
 
   /// Total number of register banks.
   unsigned NumRegBanks;
 
+  /// Hold the sizes of the register banks for all HwModes.
+  const unsigned *Sizes;
+
+  /// Current HwMode for the target.
+  unsigned HwMode;
+
   /// Keep dynamically allocated PartialMapping in a separate map.
   /// This shouldn't be needed when everything gets TableGen'ed.
   mutable DenseMap<unsigned, std::unique_ptr<const PartialMapping>>
@@ -415,7 +421,8 @@ class RegisterBankInfo {
 
   /// Create a RegisterBankInfo that can accommodate up to \p NumRegBanks
   /// RegisterBank instances.
-  RegisterBankInfo(RegisterBank **RegBanks, unsigned NumRegBanks);
+  RegisterBankInfo(const RegisterBank **RegBanks, unsigned NumRegBanks,
+                   const unsigned *Sizes, unsigned HwMode);
 
   /// This constructor is meaningless.
   /// It just provides a default constructor that can be used at link time
@@ -428,7 +435,7 @@ class RegisterBankInfo {
   }
 
   /// Get the register bank identified by \p ID.
-  RegisterBank &getRegBank(unsigned ID) {
+  const RegisterBank &getRegBank(unsigned ID) {
     assert(ID < getNumRegBanks() && "Accessing an unknown register bank");
     return *RegBanks[ID];
   }
@@ -576,6 +583,11 @@ class RegisterBankInfo {
     return const_cast<RegisterBankInfo *>(this)->getRegBank(ID);
   }
 
+  /// Get the maximum size in bits that fits in the given register bank.
+  unsigned getMaximumSize(unsigned RegBankID) const {
+    return Sizes[RegBankID + HwMode * NumRegBanks];
+  }
+
   /// Get the register bank of \p Reg.
   /// If Reg has not been assigned a register, a register class,
   /// or a register bank, then this returns nullptr.

diff  --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index 8c5c7320b5094..f960869ec28b0 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -2174,6 +2174,7 @@ MachineVerifier::visitMachineOperand(const MachineOperand *MO, unsigned MONum) {
           }
 
           const RegisterBank *RegBank = MRI->getRegBankOrNull(Reg);
+          const RegisterBankInfo *RBI = MF->getSubtarget().getRegBankInfo();
 
           // If we're post-RegBankSelect, the gvreg must have a bank.
           if (!RegBank && isFunctionRegBankSelected) {
@@ -2185,12 +2186,12 @@ MachineVerifier::visitMachineOperand(const MachineOperand *MO, unsigned MONum) {
 
           // Make sure the register fits into its register bank if any.
           if (RegBank && Ty.isValid() &&
-              RegBank->getSize() < Ty.getSizeInBits()) {
+              RBI->getMaximumSize(RegBank->getID()) < Ty.getSizeInBits()) {
             report("Register bank is too small for virtual register", MO,
                    MONum);
             errs() << "Register bank " << RegBank->getName() << " too small("
-                   << RegBank->getSize() << ") to fit " << Ty.getSizeInBits()
-                   << "-bits\n";
+                   << RBI->getMaximumSize(RegBank->getID()) << ") to fit "
+                   << Ty.getSizeInBits() << "-bits\n";
             return;
           }
         }

diff  --git a/llvm/lib/CodeGen/RegisterBank.cpp b/llvm/lib/CodeGen/RegisterBank.cpp
index 512b21aeacafc..8e0a0b0dc2824 100644
--- a/llvm/lib/CodeGen/RegisterBank.cpp
+++ b/llvm/lib/CodeGen/RegisterBank.cpp
@@ -11,6 +11,7 @@
 
 #include "llvm/CodeGen/RegisterBank.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/CodeGen/RegisterBankInfo.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/Config/llvm-config.h"
 #include "llvm/Support/Debug.h"
@@ -21,15 +22,16 @@ using namespace llvm;
 
 const unsigned RegisterBank::InvalidID = UINT_MAX;
 
-RegisterBank::RegisterBank(
-    unsigned ID, const char *Name, unsigned Size,
-    const uint32_t *CoveredClasses, unsigned NumRegClasses)
-    : ID(ID), Name(Name), Size(Size) {
+RegisterBank::RegisterBank(unsigned ID, const char *Name,
+                           const uint32_t *CoveredClasses,
+                           unsigned NumRegClasses)
+    : ID(ID), Name(Name) {
   ContainedRegClasses.resize(NumRegClasses);
   ContainedRegClasses.setBitsInMask(CoveredClasses);
 }
 
-bool RegisterBank::verify(const TargetRegisterInfo &TRI) const {
+bool RegisterBank::verify(const RegisterBankInfo &RBI,
+                          const TargetRegisterInfo &TRI) const {
   assert(isValid() && "Invalid register bank");
   for (unsigned RCId = 0, End = TRI.getNumRegClasses(); RCId != End; ++RCId) {
     const TargetRegisterClass &RC = *TRI.getRegClass(RCId);
@@ -50,7 +52,7 @@ bool RegisterBank::verify(const TargetRegisterInfo &TRI) const {
 
       // Verify that the Size of the register bank is big enough to cover
       // all the register classes it covers.
-      assert(getSize() >= TRI.getRegSizeInBits(SubRC) &&
+      assert(RBI.getMaximumSize(getID()) >= TRI.getRegSizeInBits(SubRC) &&
              "Size is not big enough for all the subclasses!");
       assert(covers(SubRC) && "Not all subclasses are covered");
     }
@@ -64,7 +66,7 @@ bool RegisterBank::covers(const TargetRegisterClass &RC) const {
 }
 
 bool RegisterBank::isValid() const {
-  return ID != InvalidID && Name != nullptr && Size != 0 &&
+  return ID != InvalidID && Name != nullptr &&
          // A register bank that does not cover anything is useless.
          !ContainedRegClasses.empty();
 }
@@ -89,7 +91,7 @@ void RegisterBank::print(raw_ostream &OS, bool IsForDebug,
   OS << getName();
   if (!IsForDebug)
     return;
-  OS << "(ID:" << getID() << ", Size:" << getSize() << ")\n"
+  OS << "(ID:" << getID() << ")\n"
      << "isValid:" << isValid() << '\n'
      << "Number of Covered register classes: " << ContainedRegClasses.count()
      << '\n';

diff  --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index 58f76c29122b0..b3f9faaca5285 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -52,9 +52,11 @@ const unsigned RegisterBankInfo::InvalidMappingID = UINT_MAX - 1;
 //------------------------------------------------------------------------------
 // RegisterBankInfo implementation.
 //------------------------------------------------------------------------------
-RegisterBankInfo::RegisterBankInfo(RegisterBank **RegBanks,
-                                   unsigned NumRegBanks)
-    : RegBanks(RegBanks), NumRegBanks(NumRegBanks) {
+RegisterBankInfo::RegisterBankInfo(const RegisterBank **RegBanks,
+                                   unsigned NumRegBanks, const unsigned *Sizes,
+                                   unsigned HwMode)
+    : RegBanks(RegBanks), NumRegBanks(NumRegBanks), Sizes(Sizes),
+      HwMode(HwMode) {
 #ifndef NDEBUG
   for (unsigned Idx = 0, End = getNumRegBanks(); Idx != End; ++Idx) {
     assert(RegBanks[Idx] != nullptr && "Invalid RegisterBank");
@@ -70,7 +72,7 @@ bool RegisterBankInfo::verify(const TargetRegisterInfo &TRI) const {
     assert(Idx == RegBank.getID() &&
            "ID does not match the index in the array");
     LLVM_DEBUG(dbgs() << "Verify " << RegBank << '\n');
-    assert(RegBank.verify(TRI) && "RegBank is invalid");
+    assert(RegBank.verify(*this, TRI) && "RegBank is invalid");
   }
 #endif // NDEBUG
   return true;
@@ -516,12 +518,14 @@ LLVM_DUMP_METHOD void RegisterBankInfo::PartialMapping::dump() const {
 }
 #endif
 
-bool RegisterBankInfo::PartialMapping::verify() const {
+bool RegisterBankInfo::PartialMapping::verify(
+    const RegisterBankInfo &RBI) const {
   assert(RegBank && "Register bank not set");
   assert(Length && "Empty mapping");
   assert((StartIdx <= getHighBitIdx()) && "Overflow, switch to APInt?");
   // Check if the minimum width fits into RegBank.
-  assert(RegBank->getSize() >= Length && "Register bank too small for Mask");
+  assert(RBI.getMaximumSize(RegBank->getID()) >= Length &&
+         "Register bank too small for Mask");
   return true;
 }
 
@@ -546,13 +550,14 @@ bool RegisterBankInfo::ValueMapping::partsAllUniform() const {
   return true;
 }
 
-bool RegisterBankInfo::ValueMapping::verify(unsigned MeaningfulBitWidth) const {
+bool RegisterBankInfo::ValueMapping::verify(const RegisterBankInfo &RBI,
+                                            unsigned MeaningfulBitWidth) const {
   assert(NumBreakDowns && "Value mapped nowhere?!");
   unsigned OrigValueBitWidth = 0;
   for (const RegisterBankInfo::PartialMapping &PartMap : *this) {
     // Check that each register bank is big enough to hold the partial value:
     // this check is done by PartialMapping::verify
-    assert(PartMap.verify() && "Partial mapping is invalid");
+    assert(PartMap.verify(RBI) && "Partial mapping is invalid");
     // The original value should completely be mapped.
     // Thus the maximum accessed index + 1 is the size of the original value.
     OrigValueBitWidth =
@@ -626,8 +631,9 @@ bool RegisterBankInfo::InstructionMapping::verify(
     (void)MOMapping;
     // Register size in bits.
     // This size must match what the mapping expects.
-    assert(MOMapping.verify(RBI->getSizeInBits(
-               Reg, MF.getRegInfo(), *MF.getSubtarget().getRegisterInfo())) &&
+    assert(MOMapping.verify(*RBI, RBI->getSizeInBits(
+                                      Reg, MF.getRegInfo(),
+                                      *MF.getSubtarget().getRegisterInfo())) &&
            "Value mapping is invalid");
   }
   return true;

diff  --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 0dbfb4c743b0c..0314a3b65ebdd 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -71,7 +71,8 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
     // GR64all + its subclasses.
     assert(RBGPR.covers(*TRI.getRegClass(AArch64::GPR32RegClassID)) &&
            "Subclass not added?");
-    assert(RBGPR.getSize() == 128 && "GPRs should hold up to 128-bit");
+    assert(getMaximumSize(RBGPR.getID()) == 128 &&
+           "GPRs should hold up to 128-bit");
 
     // The FPR register bank is fully defined by all the registers in
     // GR64all + its subclasses.
@@ -79,12 +80,13 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
            "Subclass not added?");
     assert(RBFPR.covers(*TRI.getRegClass(AArch64::FPR64RegClassID)) &&
            "Subclass not added?");
-    assert(RBFPR.getSize() == 512 &&
+    assert(getMaximumSize(RBFPR.getID()) == 512 &&
            "FPRs should hold up to 512-bit via QQQQ sequence");
 
     assert(RBCCR.covers(*TRI.getRegClass(AArch64::CCRRegClassID)) &&
            "Class not added?");
-    assert(RBCCR.getSize() == 32 && "CCR should hold up to 32-bit");
+    assert(getMaximumSize(RBCCR.getID()) == 32 &&
+           "CCR should hold up to 32-bit");
 
     // Check that the TableGen'ed like file is in sync we our expectations.
     // First, the Idx.

diff  --git a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
index 527fefbd291ea..f7977941e8951 100644
--- a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
@@ -162,7 +162,8 @@ ARMRegisterBankInfo::ARMRegisterBankInfo(const TargetRegisterInfo &TRI) {
            "Subclass not added?");
     assert(RBGPR.covers(*TRI.getRegClass(ARM::tGPROdd_and_tcGPRRegClassID)) &&
            "Subclass not added?");
-    assert(RBGPR.getSize() == 32 && "GPRs should hold up to 32-bit");
+    assert(getMaximumSize(RBGPR.getID()) == 32 &&
+           "GPRs should hold up to 32-bit");
 
 #ifndef NDEBUG
     ARM::checkPartialMappings();

diff  --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index 5b208856c5325..9b601902ad20b 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -22,4 +22,5 @@
 
 using namespace llvm;
 
-RISCVRegisterBankInfo::RISCVRegisterBankInfo(const TargetRegisterInfo &TRI) {}
+RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
+    : RISCVGenRegisterBankInfo(HwMode) {}

diff  --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
index 7cd692e8cc292..ee6d4db278809 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
@@ -31,7 +31,7 @@ class RISCVGenRegisterBankInfo : public RegisterBankInfo {
 /// This class provides the information for the target register banks.
 class RISCVRegisterBankInfo final : public RISCVGenRegisterBankInfo {
 public:
-  RISCVRegisterBankInfo(const TargetRegisterInfo &TRI);
+  RISCVRegisterBankInfo(unsigned HwMode);
 };
 } // end namespace llvm
 #endif

diff  --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index f05753e61c363..eec2e7359eda6 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -86,7 +86,7 @@ RISCVSubtarget::RISCVSubtarget(const Triple &TT, StringRef CPU,
   CallLoweringInfo.reset(new RISCVCallLowering(*getTargetLowering()));
   Legalizer.reset(new RISCVLegalizerInfo(*this));
 
-  auto *RBI = new RISCVRegisterBankInfo(*getRegisterInfo());
+  auto *RBI = new RISCVRegisterBankInfo(getHwMode());
   RegBankInfo.reset(RBI);
   InstSelector.reset(createRISCVInstructionSelector(
       *static_cast<const RISCVTargetMachine *>(&TM), *this, *RBI));

diff  --git a/llvm/lib/Target/X86/X86RegisterBankInfo.cpp b/llvm/lib/Target/X86/X86RegisterBankInfo.cpp
index 733db70f14a2e..3160969e81e4d 100644
--- a/llvm/lib/Target/X86/X86RegisterBankInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterBankInfo.cpp
@@ -36,7 +36,8 @@ X86RegisterBankInfo::X86RegisterBankInfo(const TargetRegisterInfo &TRI) {
   // GR64 + its subclasses.
   assert(RBGPR.covers(*TRI.getRegClass(X86::GR64RegClassID)) &&
          "Subclass not added?");
-  assert(RBGPR.getSize() == 64 && "GPRs should hold up to 64-bit");
+  assert(getMaximumSize(RBGPR.getID()) == 64 &&
+         "GPRs should hold up to 64-bit");
 }
 
 const RegisterBank &

diff  --git a/llvm/utils/TableGen/RegisterBankEmitter.cpp b/llvm/utils/TableGen/RegisterBankEmitter.cpp
index c5ba6a897c77e..2d23bf86b6ad8 100644
--- a/llvm/utils/TableGen/RegisterBankEmitter.cpp
+++ b/llvm/utils/TableGen/RegisterBankEmitter.cpp
@@ -37,11 +37,11 @@ class RegisterBank {
   RegisterClassesTy RCs;
 
   /// The register class with the largest register size.
-  const CodeGenRegisterClass *RCWithLargestRegsSize;
+  std::vector<const CodeGenRegisterClass *> RCsWithLargestRegSize;
 
 public:
-  RegisterBank(const Record &TheDef)
-      : TheDef(TheDef), RCWithLargestRegsSize(nullptr) {}
+  RegisterBank(const Record &TheDef, unsigned NumModeIds)
+      : TheDef(TheDef), RCsWithLargestRegSize(NumModeIds) {}
 
   /// Get the human-readable name for the bank.
   StringRef getName() const { return TheDef.getValueAsString("Name"); }
@@ -79,18 +79,21 @@ class RegisterBank {
     //        register size anywhere (we could sum the sizes of the subregisters
     //        but there may be additional bits too) and we can't derive it from
     //        the VT's reliably due to Untyped.
-    if (RCWithLargestRegsSize == nullptr)
-      RCWithLargestRegsSize = RC;
-    else if (RCWithLargestRegsSize->RSI.get(DefaultMode).SpillSize <
-             RC->RSI.get(DefaultMode).SpillSize)
-      RCWithLargestRegsSize = RC;
-    assert(RCWithLargestRegsSize && "RC was nullptr?");
+    unsigned NumModeIds = RCsWithLargestRegSize.size();
+    for (unsigned M = 0; M < NumModeIds; ++M) {
+      if (RCsWithLargestRegSize[M] == nullptr)
+        RCsWithLargestRegSize[M] = RC;
+      else if (RCsWithLargestRegSize[M]->RSI.get(M).SpillSize <
+               RC->RSI.get(M).SpillSize)
+        RCsWithLargestRegSize[M] = RC;
+      assert(RCsWithLargestRegSize[M] && "RC was nullptr?");
+    }
 
     RCs.emplace_back(RC);
   }
 
-  const CodeGenRegisterClass *getRCWithLargestRegsSize() const {
-    return RCWithLargestRegsSize;
+  const CodeGenRegisterClass *getRCWithLargestRegSize(unsigned HwMode) const {
+    return RCsWithLargestRegSize[HwMode];
   }
 
   iterator_range<typename RegisterClassesTy::const_iterator>
@@ -144,9 +147,10 @@ void RegisterBankEmitter::emitBaseClassDefinition(
     raw_ostream &OS, const StringRef TargetName,
     const std::vector<RegisterBank> &Banks) {
   OS << "private:\n"
-     << "  static RegisterBank *RegBanks[];\n\n"
+     << "  static const RegisterBank *RegBanks[];\n"
+     << "  static const unsigned Sizes[];\n\n"
      << "protected:\n"
-     << "  " << TargetName << "GenRegisterBankInfo();\n"
+     << "  " << TargetName << "GenRegisterBankInfo(unsigned HwMode = 0);\n"
      << "\n";
 }
 
@@ -211,6 +215,7 @@ void RegisterBankEmitter::emitBaseClassImplementation(
     raw_ostream &OS, StringRef TargetName,
     std::vector<RegisterBank> &Banks) {
   const CodeGenRegBank &RegisterClassHierarchy = Target.getRegBank();
+  const CodeGenHwModes &CGH = Target.getHwModes();
 
   OS << "namespace llvm {\n"
      << "namespace " << TargetName << " {\n";
@@ -241,11 +246,8 @@ void RegisterBankEmitter::emitBaseClassImplementation(
   for (const auto &Bank : Banks) {
     std::string QualifiedBankID =
         (TargetName + "::" + Bank.getEnumeratorName()).str();
-    const CodeGenRegisterClass &RC = *Bank.getRCWithLargestRegsSize();
-    unsigned Size = RC.RSI.get(DefaultMode).SpillSize;
-    OS << "RegisterBank " << Bank.getInstanceVarName() << "(/* ID */ "
-       << QualifiedBankID << ", /* Name */ \"" << Bank.getName()
-       << "\", /* Size */ " << Size << ", "
+    OS << "const RegisterBank " << Bank.getInstanceVarName() << "(/* ID */ "
+       << QualifiedBankID << ", /* Name */ \"" << Bank.getName() << "\", "
        << "/* CoveredRegClasses */ " << Bank.getCoverageArrayName()
        << ", /* NumRegClasses */ "
        << RegisterClassHierarchy.getRegClasses().size() << ");\n";
@@ -253,16 +255,33 @@ void RegisterBankEmitter::emitBaseClassImplementation(
   OS << "} // end namespace " << TargetName << "\n"
      << "\n";
 
-  OS << "RegisterBank *" << TargetName
+  OS << "const RegisterBank *" << TargetName
      << "GenRegisterBankInfo::RegBanks[] = {\n";
   for (const auto &Bank : Banks)
     OS << "    &" << TargetName << "::" << Bank.getInstanceVarName() << ",\n";
   OS << "};\n\n";
 
+  unsigned NumModeIds = CGH.getNumModeIds();
+  OS << "const unsigned " << TargetName << "GenRegisterBankInfo::Sizes[] = {\n";
+  for (unsigned M = 0; M < NumModeIds; ++M) {
+    OS << "    // Mode = " << M << " (";
+    if (M == DefaultMode)
+      OS << "Default";
+    else
+      OS << CGH.getMode(M).Name;
+    OS << ")\n";
+    for (const auto &Bank : Banks) {
+      const CodeGenRegisterClass &RC = *Bank.getRCWithLargestRegSize(M);
+      unsigned Size = RC.RSI.get(M).SpillSize;
+      OS << "    " << Size << ",\n";
+    }
+  }
+  OS << "};\n\n";
+
   OS << TargetName << "GenRegisterBankInfo::" << TargetName
-     << "GenRegisterBankInfo()\n"
+     << "GenRegisterBankInfo(unsigned HwMode)\n"
      << "    : RegisterBankInfo(RegBanks, " << TargetName
-     << "::NumRegisterBanks) {\n"
+     << "::NumRegisterBanks, Sizes, HwMode) {\n"
      << "  // Assert that RegBank indices match their ID's\n"
      << "#ifndef NDEBUG\n"
      << "  for (auto RB : enumerate(RegBanks))\n"
@@ -275,12 +294,13 @@ void RegisterBankEmitter::emitBaseClassImplementation(
 void RegisterBankEmitter::run(raw_ostream &OS) {
   StringRef TargetName = Target.getName();
   const CodeGenRegBank &RegisterClassHierarchy = Target.getRegBank();
+  const CodeGenHwModes &CGH = Target.getHwModes();
 
   Records.startTimer("Analyze records");
   std::vector<RegisterBank> Banks;
   for (const auto &V : Records.getAllDerivedDefinitions("RegisterBank")) {
     SmallPtrSet<const CodeGenRegisterClass *, 8> VisitedRCs;
-    RegisterBank Bank(*V);
+    RegisterBank Bank(*V, CGH.getNumModeIds());
 
     for (const CodeGenRegisterClass *RC :
          Bank.getExplicitlySpecifiedRegisterClasses(RegisterClassHierarchy)) {


        


More information about the llvm-commits mailing list