[llvm] 859b05b - [RDF] Allow RegisterRef to contain register unit

Krzysztof Parzyszek via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 9 06:19:48 PDT 2023


Author: Krzysztof Parzyszek
Date: 2023-06-09T06:19:03-07:00
New Revision: 859b05b02d3fd9ab6b77f2bed8df6902fe704806

URL: https://github.com/llvm/llvm-project/commit/859b05b02d3fd9ab6b77f2bed8df6902fe704806
DIFF: https://github.com/llvm/llvm-project/commit/859b05b02d3fd9ab6b77f2bed8df6902fe704806.diff

LOG: [RDF] Allow RegisterRef to contain register unit

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/RDFGraph.h
    llvm/include/llvm/CodeGen/RDFRegisters.h
    llvm/lib/CodeGen/RDFGraph.cpp
    llvm/lib/CodeGen/RDFLiveness.cpp
    llvm/lib/CodeGen/RDFRegisters.cpp
    llvm/lib/Target/Hexagon/RDFCopy.cpp
    llvm/lib/Target/Hexagon/RDFCopy.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/RDFGraph.h b/llvm/include/llvm/CodeGen/RDFGraph.h
index 97673a249237c..79ff8bc2d9b50 100644
--- a/llvm/include/llvm/CodeGen/RDFGraph.h
+++ b/llvm/include/llvm/CodeGen/RDFGraph.h
@@ -884,7 +884,7 @@ NodeAddr<RefNode *> RefNode::getNextRef(RegisterRef RR, Predicate P,
   while (NA.Addr != this) {
     if (NA.Addr->getType() == NodeAttrs::Ref) {
       NodeAddr<RefNode *> RA = NA;
-      if (RA.Addr->getRegRef(G) == RR && P(NA))
+      if (G.getPRI().equal_to(RA.Addr->getRegRef(G), RR) && P(NA))
         return NA;
       if (NextOnly)
         break;

diff  --git a/llvm/include/llvm/CodeGen/RDFRegisters.h b/llvm/include/llvm/CodeGen/RDFRegisters.h
index 86c00aeb47666..72fd59c82c368 100644
--- a/llvm/include/llvm/CodeGen/RDFRegisters.h
+++ b/llvm/include/llvm/CodeGen/RDFRegisters.h
@@ -14,6 +14,7 @@
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/MC/LaneBitmask.h"
+#include "llvm/MC/MCRegister.h"
 #include <cassert>
 #include <cstdint>
 #include <map>
@@ -26,6 +27,7 @@ class MachineFunction;
 class raw_ostream;
 
 namespace rdf {
+struct RegisterAggr;
 
 using RegisterId = uint32_t;
 
@@ -70,36 +72,51 @@ template <typename T, unsigned N = 32> struct IndexedSet {
 
 struct RegisterRef {
   RegisterId Reg = 0;
-  LaneBitmask Mask = LaneBitmask::getNone();
+  LaneBitmask Mask = LaneBitmask::getNone(); // Only for registers.
 
-  RegisterRef() = default;
-  explicit RegisterRef(RegisterId R, LaneBitmask M = LaneBitmask::getAll())
-      : Reg(R), Mask(R != 0 ? M : LaneBitmask::getNone()) {}
+  constexpr RegisterRef() = default;
+  constexpr explicit RegisterRef(RegisterId R,
+                                 LaneBitmask M = LaneBitmask::getAll())
+      : Reg(R), Mask(isRegId(R) && R != 0 ? M : LaneBitmask::getNone()) {}
 
-  operator bool() const { return Reg != 0 && Mask.any(); }
+  // Classify null register as a "register".
+  constexpr bool isReg() const { return Reg == 0 || isRegId(Reg); }
+  constexpr bool isUnit() const { return isUnitId(Reg); }
+  constexpr bool isMask() const { return isMaskId(Reg); }
 
-  bool operator==(const RegisterRef &RR) const {
-    return Reg == RR.Reg && Mask == RR.Mask;
+  constexpr operator bool() const {
+    return !isReg() || (Reg != 0 && Mask.any());
   }
 
-  bool operator!=(const RegisterRef &RR) const { return !operator==(RR); }
+  size_t hash() const { // std::hash is not constexpr, so neither is this.
+    return std::hash<RegisterId>{}(Reg) ^
+           std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
+  }
 
-  bool operator<(const RegisterRef &RR) const {
-    return Reg < RR.Reg || (Reg == RR.Reg && Mask < RR.Mask);
+  static constexpr bool isRegId(unsigned Id) {
+    return Register::isPhysicalRegister(Id);
+  }
+  static constexpr bool isUnitId(unsigned Id) {
+    return Register::isVirtualRegister(Id);
+  }
+  static constexpr bool isMaskId(unsigned Id) {
+    return Register::isStackSlot(Id);
   }
 
-  size_t hash() const {
-    return std::hash<RegisterId>{}(Reg) ^
-           std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
+  static RegisterId toUnitId(unsigned Idx) {
+    return Register::index2VirtReg(Idx);
   }
+  static unsigned toRegUnit(RegisterId U) { return Register::virtReg2Index(U); }
+
+  bool operator<(RegisterRef) const = delete;
+  bool operator==(RegisterRef) const = delete;
+  bool operator!=(RegisterRef) const = delete;
 };
 
 struct PhysicalRegisterInfo {
   PhysicalRegisterInfo(const TargetRegisterInfo &tri,
                        const MachineFunction &mf);
 
-  static bool isRegMaskId(RegisterId R) { return Register::isStackSlot(R); }
-
   RegisterId getRegMaskId(const uint32_t *RM) const {
     return Register::index2StackSlot(RegMasks.find(RM));
   }
@@ -109,11 +126,13 @@ struct PhysicalRegisterInfo {
   }
 
   bool alias(RegisterRef RA, RegisterRef RB) const {
-    if (!isRegMaskId(RA.Reg))
-      return !isRegMaskId(RB.Reg) ? aliasRR(RA, RB) : aliasRM(RA, RB);
-    return !isRegMaskId(RB.Reg) ? aliasRM(RB, RA) : aliasMM(RA, RB);
+    if (!RA.isMask())
+      return !RB.isMask() ? aliasRR(RA, RB) : aliasRM(RA, RB);
+    return !RB.isMask() ? aliasRM(RB, RA) : aliasMM(RA, RB);
   }
 
+  // Returns the set of aliased physical registers or register masks.
+  // The returned set does not contain register units.
   std::set<RegisterId> getAliasSet(RegisterId Reg) const;
 
   RegisterRef getRefForUnit(uint32_t U) const {
@@ -131,6 +150,12 @@ struct PhysicalRegisterInfo {
   RegisterRef mapTo(RegisterRef RR, unsigned R) const;
   const TargetRegisterInfo &getTRI() const { return TRI; }
 
+  bool equal_to(RegisterRef A, RegisterRef B) const;
+  bool less(RegisterRef A, RegisterRef B) const;
+
+  void print(raw_ostream &OS, RegisterRef A) const;
+  void print(raw_ostream &OS, const RegisterAggr &A) const;
+
 private:
   struct RegInfo {
     const TargetRegisterClass *RegClass = nullptr;
@@ -168,6 +193,8 @@ struct RegisterAggr {
   bool hasAliasOf(RegisterRef RR) const;
   bool hasCoverOf(RegisterRef RR) const;
 
+  const PhysicalRegisterInfo &getPRI() const { return PRI; }
+
   bool operator==(const RegisterAggr &A) const {
     return DenseMapInfo<BitVector>::isEqual(Units, A.Units);
   }
@@ -190,9 +217,7 @@ struct RegisterAggr {
 
   size_t hash() const { return DenseMapInfo<BitVector>::getHashValue(Units); }
 
-  void print(raw_ostream &OS) const;
-
-  struct rr_iterator {
+  struct ref_iterator {
     using MapType = std::map<RegisterId, LaneBitmask>;
 
   private:
@@ -202,35 +227,39 @@ struct RegisterAggr {
     const RegisterAggr *Owner;
 
   public:
-    rr_iterator(const RegisterAggr &RG, bool End);
+    ref_iterator(const RegisterAggr &RG, bool End);
 
     RegisterRef operator*() const {
       return RegisterRef(Pos->first, Pos->second);
     }
 
-    rr_iterator &operator++() {
+    ref_iterator &operator++() {
       ++Pos;
       ++Index;
       return *this;
     }
 
-    bool operator==(const rr_iterator &I) const {
+    bool operator==(const ref_iterator &I) const {
       assert(Owner == I.Owner);
       (void)Owner;
       return Index == I.Index;
     }
 
-    bool operator!=(const rr_iterator &I) const { return !(*this == I); }
+    bool operator!=(const ref_iterator &I) const { return !(*this == I); }
   };
 
-  rr_iterator rr_begin() const { return rr_iterator(*this, false); }
-  rr_iterator rr_end() const { return rr_iterator(*this, true); }
+  ref_iterator ref_begin() const { return ref_iterator(*this, false); }
+  ref_iterator ref_end() const { return ref_iterator(*this, true); }
+
+  using unit_iterator = typename BitVector::const_set_bits_iterator;
+  unit_iterator unit_begin() const { return Units.set_bits_begin(); }
+  unit_iterator unit_end() const { return Units.set_bits_end(); }
 
-  iterator_range<rr_iterator> refs() {
-    return make_range(rr_begin(), rr_end());
+  iterator_range<ref_iterator> refs() const {
+    return make_range(ref_begin(), ref_end());
   }
-  iterator_range<rr_iterator> refs() const {
-    return make_range(rr_begin(), rr_end());
+  iterator_range<unit_iterator> units() const {
+    return make_range(unit_begin(), unit_end());
   }
 
 private:
@@ -263,34 +292,65 @@ template <typename KeyType> struct RegisterAggrMap {
   using value_type = typename decltype(Map)::value_type;
 };
 
-// Optionally print the lane mask, if it is not ~0.
-struct PrintLaneMaskOpt {
-  PrintLaneMaskOpt(LaneBitmask M) : Mask(M) {}
+raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A);
+
+// Print the lane mask in a short form (or not at all if all bits are set).
+struct PrintLaneMaskShort {
+  PrintLaneMaskShort(LaneBitmask M) : Mask(M) {}
   LaneBitmask Mask;
 };
-raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskOpt &P);
+raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P);
 
-raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A);
 } // end namespace rdf
-
 } // end namespace llvm
 
 namespace std {
+
 template <> struct hash<llvm::rdf::RegisterRef> {
   size_t operator()(llvm::rdf::RegisterRef A) const { //
     return A.hash();
   }
 };
+
 template <> struct hash<llvm::rdf::RegisterAggr> {
   size_t operator()(const llvm::rdf::RegisterAggr &A) const { //
     return A.hash();
   }
 };
+
+template <> struct equal_to<llvm::rdf::RegisterRef> {
+  constexpr equal_to(const llvm::rdf::PhysicalRegisterInfo &pri) : PRI(pri) {}
+  // Only ctor above: containers must be handed an instance built from PRI.
+  bool operator()(llvm::rdf::RegisterRef A, llvm::rdf::RegisterRef B) const {
+    return PRI.equal_to(A, B);
+  }
+
+private:
+  const llvm::rdf::PhysicalRegisterInfo &PRI; // Not owned; must outlive this.
+};
+
 template <> struct equal_to<llvm::rdf::RegisterAggr> {
   bool operator()(const llvm::rdf::RegisterAggr &A,
                   const llvm::rdf::RegisterAggr &B) const {
     return A == B;
   }
 };
+
+template <> struct less<llvm::rdf::RegisterRef> {
+  constexpr less(const llvm::rdf::PhysicalRegisterInfo &pri) : PRI(pri) {}
+  // Only ctor above: containers must be handed an instance built from PRI.
+  bool operator()(llvm::rdf::RegisterRef A, llvm::rdf::RegisterRef B) const {
+    return PRI.less(A, B);
+  }
+
+private:
+  const llvm::rdf::PhysicalRegisterInfo &PRI; // Not owned; must outlive this.
+};
+
 } // namespace std
+
+namespace llvm::rdf {
+using RegisterSet = std::set<RegisterRef, std::less<RegisterRef>>;
+} // namespace llvm::rdf
+
 #endif // LLVM_CODEGEN_RDFREGISTERS_H

diff  --git a/llvm/lib/CodeGen/RDFGraph.cpp b/llvm/lib/CodeGen/RDFGraph.cpp
index db2589cc2d18e..8a9f131e9d186 100644
--- a/llvm/lib/CodeGen/RDFGraph.cpp
+++ b/llvm/lib/CodeGen/RDFGraph.cpp
@@ -46,19 +46,8 @@ using namespace rdf;
 namespace llvm {
 namespace rdf {
 
-raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskOpt &P) {
-  if (!P.Mask.all())
-    OS << ':' << PrintLaneMask(P.Mask);
-  return OS;
-}
-
 raw_ostream &operator<<(raw_ostream &OS, const Print<RegisterRef> &P) {
-  auto &TRI = P.G.getTRI();
-  if (P.Obj.Reg > 0 && P.Obj.Reg < TRI.getNumRegs())
-    OS << TRI.getName(P.Obj.Reg);
-  else
-    OS << '#' << P.Obj.Reg;
-  OS << PrintLaneMaskOpt(P.Obj.Mask);
+  P.G.getPRI().print(OS, P.Obj);
   return OS;
 }
 
@@ -327,7 +316,7 @@ raw_ostream &operator<<(raw_ostream &OS, const Print<RegisterSet> &P) {
 }
 
 raw_ostream &operator<<(raw_ostream &OS, const Print<RegisterAggr> &P) {
-  P.Obj.print(OS);
+  OS << P.Obj;
   return OS;
 }
 
@@ -906,7 +895,7 @@ void DataFlowGraph::build(unsigned Options) {
   NodeList Blocks = Func.Addr->members(*this);
 
   // Collect information about block references.
-  RegisterSet AllRefs;
+  RegisterSet AllRefs(getPRI());
   for (NodeAddr<BlockNode *> BA : Blocks)
     for (NodeAddr<InstrNode *> IA : BA.Addr->members(*this))
       for (NodeAddr<RefNode *> RA : IA.Addr->members(*this))
@@ -982,8 +971,7 @@ void DataFlowGraph::build(unsigned Options) {
 }
 
 RegisterRef DataFlowGraph::makeRegRef(unsigned Reg, unsigned Sub) const {
-  assert(PhysicalRegisterInfo::isRegMaskId(Reg) ||
-         Register::isPhysicalRegister(Reg));
+  assert(RegisterRef::isRegId(Reg) || RegisterRef::isMaskId(Reg));
   assert(Reg != 0);
   if (Sub != 0)
     Reg = TRI.getSubReg(Reg, Sub);
@@ -994,7 +982,8 @@ RegisterRef DataFlowGraph::makeRegRef(const MachineOperand &Op) const {
   assert(Op.isReg() || Op.isRegMask());
   if (Op.isReg())
     return makeRegRef(Op.getReg(), Op.getSubReg());
-  return RegisterRef(PRI.getRegMaskId(Op.getRegMask()), LaneBitmask::getAll());
+  return RegisterRef(getPRI().getRegMaskId(Op.getRegMask()),
+                     LaneBitmask::getAll());
 }
 
 // For each stack in the map DefM, push the delimiter for block B on it.
@@ -1060,7 +1049,7 @@ void DataFlowGraph::pushClobbers(NodeAddr<InstrNode *> IA, DefStackMap &DefM) {
     // The def stack traversal in linkNodeUp will check the exact aliasing.
     DefM[RR.Reg].push(DA);
     Defined.insert(RR.Reg);
-    for (RegisterId A : PRI.getAliasSet(RR.Reg)) {
+    for (RegisterId A : getPRI().getAliasSet(RR.Reg)) {
       // Check that we don't push the same def twice.
       assert(A != RR.Reg);
       if (!Defined.count(A))
@@ -1115,7 +1104,7 @@ void DataFlowGraph::pushDefs(NodeAddr<InstrNode *> IA, DefStackMap &DefM) {
     // Push the definition on the stack for the register and all aliases.
     // The def stack traversal in linkNodeUp will check the exact aliasing.
     DefM[RR.Reg].push(DA);
-    for (RegisterId A : PRI.getAliasSet(RR.Reg)) {
+    for (RegisterId A : getPRI().getAliasSet(RR.Reg)) {
       // Check that we don't push the same def twice.
       assert(A != RR.Reg);
       DefM[A].push(DA);
@@ -1162,8 +1151,10 @@ DataFlowGraph::getNextRelated(NodeAddr<InstrNode *> IA,
   auto Related = [this, RA](NodeAddr<RefNode *> TA) -> bool {
     if (TA.Addr->getKind() != RA.Addr->getKind())
       return false;
-    if (TA.Addr->getRegRef(*this) != RA.Addr->getRegRef(*this))
+    if (!getPRI().equal_to(TA.Addr->getRegRef(*this),
+                           RA.Addr->getRegRef(*this))) {
       return false;
+    }
     return true;
   };
   auto RelatedStmt = [&Related, RA](NodeAddr<RefNode *> TA) -> bool {
@@ -1276,7 +1267,7 @@ void DataFlowGraph::buildStmt(NodeAddr<BlockNode *> BA, MachineInstr &In) {
       if (Op.getReg() == 0 || Op.isUndef())
         continue;
       RegisterRef UR = makeRegRef(Op);
-      if (PRI.alias(DR, UR))
+      if (getPRI().alias(DR, UR))
         return false;
     }
     return true;
@@ -1514,7 +1505,7 @@ void DataFlowGraph::linkRefUp(NodeAddr<InstrNode *> IA, NodeAddr<T> TA,
   NodeAddr<T> TAP;
 
   // References from the def stack that have been examined so far.
-  RegisterAggr Defs(PRI);
+  RegisterAggr Defs(getPRI());
 
   for (auto I = DS.top(), E = DS.bottom(); I != E; I.down()) {
     RegisterRef QR = I->Addr->getRegRef(*this);
@@ -1554,7 +1545,7 @@ template <typename Predicate>
 void DataFlowGraph::linkStmtRefs(DefStackMap &DefM, NodeAddr<StmtNode *> SA,
                                  Predicate P) {
 #ifndef NDEBUG
-  RegisterSet Defs;
+  RegisterSet Defs(getPRI());
 #endif
 
   // Link all nodes (upwards in the data-flow) with their reaching defs.

diff  --git a/llvm/lib/CodeGen/RDFLiveness.cpp b/llvm/lib/CodeGen/RDFLiveness.cpp
index 31ab7f3ed687d..e404b02054df0 100644
--- a/llvm/lib/CodeGen/RDFLiveness.cpp
+++ b/llvm/lib/CodeGen/RDFLiveness.cpp
@@ -65,7 +65,7 @@ raw_ostream &operator<<(raw_ostream &OS, const Print<Liveness::RefMap> &P) {
   for (const auto &I : P.Obj) {
     OS << ' ' << printReg(I.first, &P.G.getTRI()) << '{';
     for (auto J = I.second.begin(), E = I.second.end(); J != E;) {
-      OS << Print(J->first, P.G) << PrintLaneMaskOpt(J->second);
+      OS << Print(J->first, P.G) << PrintLaneMaskShort(J->second);
       if (++J != E)
         OS << ',';
     }
@@ -659,6 +659,8 @@ void Liveness::computePhiInfo() {
   // The operation "clearIn" can be expensive. For a given set of intervening
   // defs, cache the result of subtracting these defs from a given register
   // ref.
+  using RefHash = std::hash<RegisterRef>;
+  using RefEqual = std::equal_to<RegisterRef>;
   using SubMap = std::unordered_map<RegisterRef, RegisterRef>;
   std::unordered_map<RegisterAggr, SubMap> Subs;
   auto ClearIn = [](RegisterRef RR, const RegisterAggr &Mid, SubMap &SM) {
@@ -690,7 +692,10 @@ void Liveness::computePhiInfo() {
 
         if (MidDefs.hasCoverOf(UR))
           continue;
-        SubMap &SM = Subs[MidDefs];
+        if (Subs.find(MidDefs) == Subs.end()) {
+          Subs.insert({MidDefs, SubMap(1, RefHash(), RefEqual(PRI))});
+        }
+        SubMap &SM = Subs.at(MidDefs);
 
         // General algorithm:
         //   for each (R,U) : U is use node of R, U is reached by PA
@@ -873,7 +878,7 @@ void Liveness::computeLiveIns() {
       std::vector<RegisterRef> LV;
       for (const MachineBasicBlock::RegisterMaskPair &LI : B.liveins())
         LV.push_back(RegisterRef(LI.PhysReg, LI.LaneMask));
-      llvm::sort(LV);
+      llvm::sort(LV, std::less<RegisterRef>(PRI));
       dbgs() << printMBBReference(B) << "\t rec = {";
       for (auto I : LV)
         dbgs() << ' ' << Print(I, DFG);
@@ -883,7 +888,7 @@ void Liveness::computeLiveIns() {
       LV.clear();
       for (RegisterRef RR : LiveMap[&B].refs())
         LV.push_back(RR);
-      llvm::sort(LV);
+      llvm::sort(LV, std::less<RegisterRef>(PRI));
       dbgs() << "\tcomp = {";
       for (auto I : LV)
         dbgs() << ' ' << Print(I, DFG);

diff  --git a/llvm/lib/CodeGen/RDFRegisters.cpp b/llvm/lib/CodeGen/RDFRegisters.cpp
index 3e9a11d5397c9..0f451b6c044ed 100644
--- a/llvm/lib/CodeGen/RDFRegisters.cpp
+++ b/llvm/lib/CodeGen/RDFRegisters.cpp
@@ -15,6 +15,7 @@
 #include "llvm/MC/LaneBitmask.h"
 #include "llvm/MC/MCRegisterInfo.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <cstdint>
@@ -104,10 +105,10 @@ PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
 }
 
 std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
-  // Do not include RR in the alias set.
+  // Do not include Reg in the alias set.
   std::set<RegisterId> AS;
-  assert(isRegMaskId(Reg) || Register::isPhysicalRegister(Reg));
-  if (isRegMaskId(Reg)) {
+  assert(!RegisterRef::isUnitId(Reg) && "No units allowed");
+  if (RegisterRef::isMaskId(Reg)) {
     // XXX SLOW
     const uint32_t *MB = getRegMaskBits(Reg);
     for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
@@ -123,6 +124,7 @@ std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
     return AS;
   }
 
+  assert(RegisterRef::isRegId(Reg));
   for (MCRegAliasIterator AI(Reg, &TRI, false); AI.isValid(); ++AI)
     AS.insert(*AI);
   for (const uint32_t *RM : RegMasks) {
@@ -134,8 +136,7 @@ std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
 }
 
 bool PhysicalRegisterInfo::aliasRR(RegisterRef RA, RegisterRef RB) const {
-  assert(Register::isPhysicalRegister(RA.Reg));
-  assert(Register::isPhysicalRegister(RB.Reg));
+  assert(RA.isReg() && RB.isReg());
 
   MCRegUnitMaskIterator UMA(RA.Reg, &TRI);
   MCRegUnitMaskIterator UMB(RB.Reg, &TRI);
@@ -165,7 +166,7 @@ bool PhysicalRegisterInfo::aliasRR(RegisterRef RA, RegisterRef RB) const {
 }
 
 bool PhysicalRegisterInfo::aliasRM(RegisterRef RR, RegisterRef RM) const {
-  assert(Register::isPhysicalRegister(RR.Reg) && isRegMaskId(RM.Reg));
+  assert(RR.isReg() && RM.isMask());
   const uint32_t *MB = getRegMaskBits(RM.Reg);
   bool Preserved = MB[RR.Reg / 32] & (1u << (RR.Reg % 32));
   // If the lane mask information is "full", e.g. when the given lane mask
@@ -200,7 +201,7 @@ bool PhysicalRegisterInfo::aliasRM(RegisterRef RR, RegisterRef RM) const {
 }
 
 bool PhysicalRegisterInfo::aliasMM(RegisterRef RM, RegisterRef RN) const {
-  assert(isRegMaskId(RM.Reg) && isRegMaskId(RN.Reg));
+  assert(RM.isMask() && RN.isMask());
   unsigned NumRegs = TRI.getNumRegs();
   const uint32_t *BM = getRegMaskBits(RM.Reg);
   const uint32_t *BN = getRegMaskBits(RN.Reg);
@@ -242,8 +243,118 @@ RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
   llvm_unreachable("Invalid arguments: unrelated registers?");
 }
 
+bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
+  if (!A.isReg() || !B.isReg()) {
+    // For non-regs, or comparing reg and non-reg, use only the Reg member.
+    return A.Reg == B.Reg;
+  }
+
+  if (A.Reg == B.Reg)
+    return A.Mask == B.Mask;
+  // Do not iterate units of the null register; it cannot equal a non-null one.
+  if (A.Reg == 0 || B.Reg == 0)
+    return false;
+
+  // Advance I past units whose lanes do not overlap Mask. Lane masks are
+  // "none" for units that don't correspond to subregs (e.g. a single unit
+  // in a leaf register, or an aliased unit); such units cover all lanes.
+  auto SkipUncovered = [](MCRegUnitMaskIterator &I, LaneBitmask Mask) {
+    for (; I.isValid(); ++I) {
+      LaneBitmask UnitMask = (*I).second;
+      if (UnitMask.none())
+        UnitMask = LaneBitmask::getAll();
+      if ((UnitMask & Mask).any())
+        return;
+    }
+  };
+  // Compare the covered reg units of A and B pairwise, in iteration order.
+  MCRegUnitMaskIterator AI(A.Reg, &getTRI());
+  MCRegUnitMaskIterator BI(B.Reg, &getTRI());
+  while (true) {
+    // Skipping before the validity check also drains uncovered trailing
+    // units, so a masked-out tail does not make the refs compare unequal.
+    SkipUncovered(AI, A.Mask);
+    SkipUncovered(BI, B.Mask);
+    if (!AI.isValid() || !BI.isValid())
+      break;
+    if ((*AI).first != (*BI).first)
+      return false;
+    ++AI;
+    ++BI;
+  }
+  // Equal only if both sequences of covered units ended together.
+  return AI.isValid() == BI.isValid();
+}
+
+bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
+  if (!A.isReg() || !B.isReg()) {
+    // For non-regs, or comparing reg and non-reg, use only the Reg member.
+    return A.Reg < B.Reg;
+  }
+
+  if (A.Reg == B.Reg)
+    return A.Mask < B.Mask;
+  if (A.Mask == B.Mask)
+    return A.Reg < B.Reg;
+  // Do not iterate units of the null register; order it by its id (first).
+  if (A.Reg == 0 || B.Reg == 0)
+    return A.Reg < B.Reg;
+
+  // Advance I past units whose lanes do not overlap Mask. Lane masks are
+  // "none" for units that don't correspond to subregs (e.g. a single unit
+  // in a leaf register, or an aliased unit); such units cover all lanes.
+  auto SkipUncovered = [](llvm::MCRegUnitMaskIterator &I, LaneBitmask Mask) {
+    for (; I.isValid(); ++I) {
+      LaneBitmask UnitMask = (*I).second;
+      if (UnitMask.none())
+        UnitMask = LaneBitmask::getAll();
+      if ((UnitMask & Mask).any())
+        return;
+    }
+  };
+  // Compare the covered reg units of A and B lexicographically.
+  llvm::MCRegUnitMaskIterator AI(A.Reg, &getTRI());
+  llvm::MCRegUnitMaskIterator BI(B.Reg, &getTRI());
+  while (true) {
+    // Skipping before the validity check also drains uncovered trailing
+    // units, so a masked-out tail is not treated as extra units.
+    SkipUncovered(AI, A.Mask);
+    SkipUncovered(BI, B.Mask);
+    if (!AI.isValid() || !BI.isValid())
+      break;
+    if ((*AI).first != (*BI).first)
+      return (*AI).first < (*BI).first;
+    ++AI;
+    ++BI;
+  }
+  // One or both have reached the end: assume invalid < valid.
+  return static_cast<int>(AI.isValid()) < static_cast<int>(BI.isValid());
+}
+
+void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
+  if (A.isReg()) { // isReg() is also true for the null register (Reg == 0).
+    if (0 < A.Reg && A.Reg < TRI.getNumRegs())
+      OS << TRI.getName(A.Reg);
+    else
+      OS << printReg(A.Reg, &TRI);
+    OS << PrintLaneMaskShort(A.Mask);
+  } else if (A.isUnit()) {
+    OS << printRegUnit(A.toRegUnit(A.Reg), &TRI);
+  } else {
+    assert(A.isMask());
+    OS << '#' << format("%08x", A.Reg);
+  }
+}
+
+void PhysicalRegisterInfo::print(raw_ostream &OS, const RegisterAggr &A) const {
+  OS << '{';
+  for (unsigned U : A.units())
+    OS << ' ' << printRegUnit(U, &TRI);
+  OS << " }";
+}
+
 bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
-  if (PhysicalRegisterInfo::isRegMaskId(RR.Reg))
+  if (RR.isMask())
     return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
 
   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
@@ -256,7 +367,7 @@ bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
 }
 
 bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
-  if (PhysicalRegisterInfo::isRegMaskId(RR.Reg)) {
+  if (RR.isMask()) {
     BitVector T(PRI.getMaskUnits(RR.Reg));
     return T.reset(Units).none();
   }
@@ -271,7 +382,7 @@ bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
 }
 
 RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
-  if (PhysicalRegisterInfo::isRegMaskId(RR.Reg)) {
+  if (RR.isMask()) {
     Units |= PRI.getMaskUnits(RR.Reg);
     return *this;
   }
@@ -357,14 +468,7 @@ RegisterRef RegisterAggr::makeRegRef() const {
   return RegisterRef(F, M);
 }
 
-void RegisterAggr::print(raw_ostream &OS) const {
-  OS << '{';
-  for (int U = Units.find_first(); U >= 0; U = Units.find_next(U))
-    OS << ' ' << printRegUnit(U, &PRI.getTRI());
-  OS << " }";
-}
-
-RegisterAggr::rr_iterator::rr_iterator(const RegisterAggr &RG, bool End)
+RegisterAggr::ref_iterator::ref_iterator(const RegisterAggr &RG, bool End)
     : Owner(&RG) {
   for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(U)) {
     RegisterRef R = RG.PRI.getRefForUnit(U);
@@ -375,6 +479,20 @@ RegisterAggr::rr_iterator::rr_iterator(const RegisterAggr &RG, bool End)
 }
 
 raw_ostream &rdf::operator<<(raw_ostream &OS, const RegisterAggr &A) {
-  A.print(OS);
+  A.getPRI().print(OS, A);
   return OS;
 }
+
+raw_ostream &rdf::operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
+  if (P.Mask.all())
+    return OS;
+  if (P.Mask.none())
+    return OS << ":*none*";
+  // Short form: 4 or 8 hex digits if the value fits, else use PrintLaneMask.
+  unsigned long long Val = P.Mask.getAsInteger(); // %llX needs this type.
+  if ((Val & 0xffff) == Val)
+    return OS << ':' << format("%04llX", Val);
+  if ((Val & 0xffffffff) == Val)
+    return OS << ':' << format("%08llX", Val);
+  return OS << ':' << PrintLaneMask(P.Mask);
+}

diff  --git a/llvm/lib/Target/Hexagon/RDFCopy.cpp b/llvm/lib/Target/Hexagon/RDFCopy.cpp
index e24f66de653d3..c26811e9cd05d 100644
--- a/llvm/lib/Target/Hexagon/RDFCopy.cpp
+++ b/llvm/lib/Target/Hexagon/RDFCopy.cpp
@@ -76,7 +76,7 @@ void CopyPropagation::recordCopy(NodeAddr<StmtNode*> SA, EqualityMap &EM) {
 
 
 void CopyPropagation::updateMap(NodeAddr<InstrNode*> IA) {
-  RegisterSet RRs;
+  RegisterSet RRs(DFG.getPRI());
   for (NodeAddr<RefNode*> RA : IA.Addr->members(DFG))
     RRs.insert(RA.Addr->getRegRef(DFG));
   bool Common = false;
@@ -107,7 +107,7 @@ bool CopyPropagation::scanBlock(MachineBasicBlock *B) {
   for (NodeAddr<InstrNode*> IA : BA.Addr->members(DFG)) {
     if (DFG.IsCode<NodeAttrs::Stmt>(IA)) {
       NodeAddr<StmtNode*> SA = IA;
-      EqualityMap EM;
+      EqualityMap EM(std::less<RegisterRef>(DFG.getPRI()));
       if (interpretAsCopy(SA.Addr->getCode(), EM))
         recordCopy(SA, EM);
     }
@@ -132,9 +132,11 @@ bool CopyPropagation::run() {
     for (NodeId I : Copies) {
       dbgs() << "Instr: " << *DFG.addr<StmtNode*>(I).Addr->getCode();
       dbgs() << "   eq: {";
-      for (auto J : CopyMap[I])
-        dbgs() << ' ' << Print<RegisterRef>(J.first, DFG) << '='
-               << Print<RegisterRef>(J.second, DFG);
+      if (CopyMap.count(I)) {
+        for (auto J : CopyMap.at(I))
+          dbgs() << ' ' << Print<RegisterRef>(J.first, DFG) << '='
+                 << Print<RegisterRef>(J.second, DFG);
+      }
       dbgs() << " }\n";
     }
     dbgs() << "\nRDef map:\n";
@@ -164,6 +166,8 @@ bool CopyPropagation::run() {
     return 0;
   };
 
+  const PhysicalRegisterInfo &PRI = DFG.getPRI();
+
   for (NodeId C : Copies) {
 #ifndef NDEBUG
     if (HasLimit && CpCount >= CpLimit)
@@ -181,7 +185,7 @@ bool CopyPropagation::run() {
       if (FR == EM.end())
         continue;
       RegisterRef SR = FR->second;
-      if (DR == SR)
+      if (PRI.equal_to(DR, SR))
         continue;
 
       auto &RDefSR = RDefMap[SR];
@@ -193,7 +197,7 @@ bool CopyPropagation::run() {
         uint16_t F = UA.Addr->getFlags();
         if ((F & NodeAttrs::PhiRef) || (F & NodeAttrs::Fixed))
           continue;
-        if (UA.Addr->getRegRef(DFG) != DR)
+        if (!PRI.equal_to(UA.Addr->getRegRef(DFG), DR))
           continue;
 
         NodeAddr<InstrNode*> IA = UA.Addr->getOwner(DFG);
@@ -233,7 +237,7 @@ bool CopyPropagation::run() {
           // Update the EM map in the copy's entry.
           auto &M = FC->second;
           for (auto &J : M) {
-            if (J.second != DR)
+            if (!PRI.equal_to(J.second, DR))
               continue;
             J.second = SR;
             break;

diff  --git a/llvm/lib/Target/Hexagon/RDFCopy.h b/llvm/lib/Target/Hexagon/RDFCopy.h
index 8bca374a52887..e4fb89892831d 100644
--- a/llvm/lib/Target/Hexagon/RDFCopy.h
+++ b/llvm/lib/Target/Hexagon/RDFCopy.h
@@ -25,7 +25,8 @@ class MachineInstr;
 namespace rdf {
 
   struct CopyPropagation {
-    CopyPropagation(DataFlowGraph &dfg) : MDT(dfg.getDT()), DFG(dfg) {}
+    CopyPropagation(DataFlowGraph &dfg) : MDT(dfg.getDT()), DFG(dfg),
+        RDefMap(std::less<RegisterRef>(dfg.getPRI())) {}
 
     virtual ~CopyPropagation() = default;
 


        


More information about the llvm-commits mailing list