[llvm] [TRI][RISCV] Add methods to get common register class of two registers (PR #118435)

Pengcheng Wang via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 22 20:25:34 PST 2024


https://github.com/wangpc-pp updated https://github.com/llvm/llvm-project/pull/118435

>From dfec9e3495063c277b458f67383d1ef3a6a992a9 Mon Sep 17 00:00:00 2001
From: Wang Pengcheng <wangpengcheng.pp at bytedance.com>
Date: Tue, 3 Dec 2024 14:50:18 +0800
Subject: [PATCH 1/4] [TRI][RISCV] Add methods to get common register class of
 two registers

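Here we add two methods, `getCommonMinimalPhysRegClass` and its LLT
version `getCommonMinimalPhysRegClassLLT`, which return the most
specific (minimal) register class of the requested type that contains
both input registers. The MVT version asserts that such a class exists;
the LLT version returns nullptr if there is none.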

We don't add these as overloads of `getMinimalPhysRegClass`, because
the two-register overloads would make some existing calls ambiguous.

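We use the new method to simplify the vector register-to-register copy
code in `RISCVInstrInfo::copyPhysReg`, replacing its hard-coded list of
RVV register classes.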
---
 .../include/llvm/CodeGen/TargetRegisterInfo.h | 15 ++++++++
 llvm/lib/CodeGen/TargetRegisterInfo.cpp       | 37 +++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 16 +++-----
 3 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
index 374f9f2e7f5696..f4bf74c8caa5b8 100644
--- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
@@ -347,6 +347,13 @@ class TargetRegisterInfo : public MCRegisterInfo {
   const TargetRegisterClass *getMinimalPhysRegClass(MCRegister Reg,
                                                     MVT VT = MVT::Other) const;
 
+  /// Returns the common Register Class of two physical registers of the given
+  /// type, picking the most sub register class of the right type that contains
+  /// these two physregs.
+  const TargetRegisterClass *
+  getCommonMinimalPhysRegClass(MCRegister Reg1, MCRegister Reg2,
+                               MVT VT = MVT::Other) const;
+
   /// Returns the Register Class of a physical register of the given type,
   /// picking the most sub register class of the right type that contains this
   /// physreg. If there is no register class compatible with the given type,
@@ -354,6 +361,14 @@ class TargetRegisterInfo : public MCRegisterInfo {
   const TargetRegisterClass *getMinimalPhysRegClassLLT(MCRegister Reg,
                                                        LLT Ty = LLT()) const;
 
+  /// Returns the common Register Class of two physical registers of the given
+  /// type, picking the most sub register class of the right type that contains
+  /// these two physregs. If there is no register class compatible with the
+  /// given type, returns nullptr.
+  const TargetRegisterClass *
+  getCommonMinimalPhysRegClassLLT(MCRegister Reg1, MCRegister Reg2,
+                                  LLT Ty = LLT()) const;
+
   /// Return the maximal subclass of the given register class that is
   /// allocatable or NULL.
   const TargetRegisterClass *
diff --git a/llvm/lib/CodeGen/TargetRegisterInfo.cpp b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
index 032f1a33e75c43..4eff8bdbf744be 100644
--- a/llvm/lib/CodeGen/TargetRegisterInfo.cpp
+++ b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
@@ -222,6 +222,25 @@ TargetRegisterInfo::getMinimalPhysRegClass(MCRegister reg, MVT VT) const {
   return BestRC;
 }
 
+const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClass(
+    MCRegister Reg1, MCRegister Reg2, MVT VT) const {
+  assert(Register::isPhysicalRegister(Reg1) &&
+         Register::isPhysicalRegister(Reg2) &&
+         "Reg1/Reg2 must be a physical register");
+
+  // Pick the most sub register class of the right type that contains
+  // this physreg.
+  const TargetRegisterClass *BestRC = nullptr;
+  for (const TargetRegisterClass *RC : regclasses()) {
+    if ((VT == MVT::Other || isTypeLegalForClass(*RC, VT)) &&
+        RC->contains(Reg1, Reg2) && (!BestRC || BestRC->hasSubClass(RC)))
+      BestRC = RC;
+  }
+
+  assert(BestRC && "Couldn't find the register class");
+  return BestRC;
+}
+
 const TargetRegisterClass *
 TargetRegisterInfo::getMinimalPhysRegClassLLT(MCRegister reg, LLT Ty) const {
   assert(Register::isPhysicalRegister(reg) &&
@@ -239,6 +258,24 @@ TargetRegisterInfo::getMinimalPhysRegClassLLT(MCRegister reg, LLT Ty) const {
   return BestRC;
 }
 
+const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClassLLT(
+    MCRegister Reg1, MCRegister Reg2, LLT Ty) const {
+  assert(Register::isPhysicalRegister(Reg1) &&
+         Register::isPhysicalRegister(Reg2) &&
+         "Reg1/Reg2 must be a physical register");
+
+  // Pick the most sub register class of the right type that contains
+  // this physreg.
+  const TargetRegisterClass *BestRC = nullptr;
+  for (const TargetRegisterClass *RC : regclasses()) {
+    if ((!Ty.isValid() || isTypeLegalForClass(*RC, Ty)) &&
+        RC->contains(Reg1, Reg2) && (!BestRC || BestRC->hasSubClass(RC)))
+      BestRC = RC;
+  }
+
+  return BestRC;
+}
+
 /// getAllocatableSetForRC - Toggle the bits that represent allocatable
 /// registers for the specific register class.
 static void getAllocatableSetForRC(const MachineFunction &MF,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index ee720ceb22b00f..abc906aadfbaed 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -564,17 +564,11 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   }
 
   // VR->VR copies.
-  static const TargetRegisterClass *RVVRegClasses[] = {
-      &RISCV::VRRegClass,     &RISCV::VRM2RegClass,   &RISCV::VRM4RegClass,
-      &RISCV::VRM8RegClass,   &RISCV::VRN2M1RegClass, &RISCV::VRN2M2RegClass,
-      &RISCV::VRN2M4RegClass, &RISCV::VRN3M1RegClass, &RISCV::VRN3M2RegClass,
-      &RISCV::VRN4M1RegClass, &RISCV::VRN4M2RegClass, &RISCV::VRN5M1RegClass,
-      &RISCV::VRN6M1RegClass, &RISCV::VRN7M1RegClass, &RISCV::VRN8M1RegClass};
-  for (const auto &RegClass : RVVRegClasses) {
-    if (RegClass->contains(DstReg, SrcReg)) {
-      copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RegClass);
-      return;
-    }
+  const TargetRegisterClass *RegClass =
+      TRI->getCommonMinimalPhysRegClass(SrcReg, DstReg);
+  if (RISCVRegisterInfo::isRVVRegClass(RegClass)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RegClass);
+    return;
   }
 
   llvm_unreachable("Impossible reg-to-reg copy");

>From bf69e06b6522a1157e004b9eb90c2a7e4743a0a3 Mon Sep 17 00:00:00 2001
From: Wang Pengcheng <wangpengcheng.pp at bytedance.com>
Date: Thu, 5 Dec 2024 20:52:40 +0800
Subject: [PATCH 2/4] Use hasSuperClassEq in FindRegWithEncoding

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index abc906aadfbaed..eb26331e2d5449 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -390,7 +390,7 @@ void RISCVInstrInfo::copyPhysRegVector(
   auto FindRegWithEncoding = [TRI](const TargetRegisterClass &RegClass,
                                    uint16_t Encoding) {
     MCRegister Reg = RISCV::V0 + Encoding;
-    if (&RegClass == &RISCV::VRRegClass)
+    if (RegClass.hasSuperClassEq(&RISCV::VRRegClass))
       return Reg;
     return TRI->getMatchingSuperReg(Reg, RISCV::sub_vrm1_0, &RegClass);
   };

>From 0d573512a2ee01305b7c716f1a2003dd5e5598ae Mon Sep 17 00:00:00 2001
From: Wang Pengcheng <wangpengcheng.pp at bytedance.com>
Date: Fri, 6 Dec 2024 14:59:53 +0800
Subject: [PATCH 3/4] Get LMUL from TSFlags

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index eb26331e2d5449..f24940795e433f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -390,7 +390,7 @@ void RISCVInstrInfo::copyPhysRegVector(
   auto FindRegWithEncoding = [TRI](const TargetRegisterClass &RegClass,
                                    uint16_t Encoding) {
     MCRegister Reg = RISCV::V0 + Encoding;
-    if (RegClass.hasSuperClassEq(&RISCV::VRRegClass))
+    if (RISCVRI::getLMul(RegClass.TSFlags) == RISCVII::LMUL_1)
       return Reg;
     return TRI->getMatchingSuperReg(Reg, RISCV::sub_vrm1_0, &RegClass);
   };

>From ddc747a015e8ee9ac062e1721523b6e43234feb4 Mon Sep 17 00:00:00 2001
From: Wang Pengcheng <wangpengcheng.pp at bytedance.com>
Date: Fri, 6 Dec 2024 18:13:40 +0800
Subject: [PATCH 4/4] Use template functions

---
 llvm/lib/CodeGen/TargetRegisterInfo.cpp | 88 +++++++++++++------------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/CodeGen/TargetRegisterInfo.cpp b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
index 4eff8bdbf744be..af62623ece6ab6 100644
--- a/llvm/lib/CodeGen/TargetRegisterInfo.cpp
+++ b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
@@ -201,79 +201,83 @@ TargetRegisterInfo::getAllocatableClass(const TargetRegisterClass *RC) const {
   return nullptr;
 }
 
-/// getMinimalPhysRegClass - Returns the Register Class of a physical
-/// register of the given type, picking the most sub register class of
-/// the right type that contains this physreg.
-const TargetRegisterClass *
-TargetRegisterInfo::getMinimalPhysRegClass(MCRegister reg, MVT VT) const {
-  assert(Register::isPhysicalRegister(reg) &&
+template <typename TypeT>
+static const TargetRegisterClass *
+getMinimalPhysRegClass(const TargetRegisterInfo *TRI, MCRegister Reg,
+                       TypeT Ty) {
+  static_assert(std::is_same_v<TypeT, MVT> || std::is_same_v<TypeT, LLT>);
+  assert(Register::isPhysicalRegister(Reg) &&
          "reg must be a physical register");
 
+  bool IsDefault = [&]() {
+    if constexpr (std::is_same_v<TypeT, MVT>)
+      return Ty == MVT::Other;
+    else
+      return !Ty.isValid();
+  }();
+
   // Pick the most sub register class of the right type that contains
   // this physreg.
-  const TargetRegisterClass* BestRC = nullptr;
-  for (const TargetRegisterClass* RC : regclasses()) {
-    if ((VT == MVT::Other || isTypeLegalForClass(*RC, VT)) &&
-        RC->contains(reg) && (!BestRC || BestRC->hasSubClass(RC)))
+  const TargetRegisterClass *BestRC = nullptr;
+  for (const TargetRegisterClass *RC : TRI->regclasses()) {
+    if ((IsDefault || TRI->isTypeLegalForClass(*RC, Ty)) && RC->contains(Reg) &&
+        (!BestRC || BestRC->hasSubClass(RC)))
       BestRC = RC;
   }
 
-  assert(BestRC && "Couldn't find the register class");
+  if constexpr (std::is_same_v<TypeT, MVT>)
+    assert(BestRC && "Couldn't find the register class");
   return BestRC;
 }
 
-const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClass(
-    MCRegister Reg1, MCRegister Reg2, MVT VT) const {
+template <typename TypeT>
+static const TargetRegisterClass *
+getCommonMinimalPhysRegClass(const TargetRegisterInfo *TRI, MCRegister Reg1,
+                             MCRegister Reg2, TypeT Ty) {
+  static_assert(std::is_same_v<TypeT, MVT> || std::is_same_v<TypeT, LLT>);
   assert(Register::isPhysicalRegister(Reg1) &&
          Register::isPhysicalRegister(Reg2) &&
          "Reg1/Reg2 must be a physical register");
 
+  bool IsDefault = [&]() {
+    if constexpr (std::is_same_v<TypeT, MVT>)
+      return Ty == MVT::Other;
+    else
+      return !Ty.isValid();
+  }();
+
   // Pick the most sub register class of the right type that contains
   // this physreg.
   const TargetRegisterClass *BestRC = nullptr;
-  for (const TargetRegisterClass *RC : regclasses()) {
-    if ((VT == MVT::Other || isTypeLegalForClass(*RC, VT)) &&
+  for (const TargetRegisterClass *RC : TRI->regclasses()) {
+    if ((IsDefault || TRI->isTypeLegalForClass(*RC, Ty)) &&
         RC->contains(Reg1, Reg2) && (!BestRC || BestRC->hasSubClass(RC)))
       BestRC = RC;
   }
 
-  assert(BestRC && "Couldn't find the register class");
+  if constexpr (std::is_same_v<TypeT, MVT>)
+    assert(BestRC && "Couldn't find the register class");
   return BestRC;
 }
 
 const TargetRegisterClass *
-TargetRegisterInfo::getMinimalPhysRegClassLLT(MCRegister reg, LLT Ty) const {
-  assert(Register::isPhysicalRegister(reg) &&
-         "reg must be a physical register");
+TargetRegisterInfo::getMinimalPhysRegClass(MCRegister Reg, MVT VT) const {
+  return ::getMinimalPhysRegClass(this, Reg, VT);
+}
 
-  // Pick the most sub register class of the right type that contains
-  // this physreg.
-  const TargetRegisterClass *BestRC = nullptr;
-  for (const TargetRegisterClass *RC : regclasses()) {
-    if ((!Ty.isValid() || isTypeLegalForClass(*RC, Ty)) && RC->contains(reg) &&
-        (!BestRC || BestRC->hasSubClass(RC)))
-      BestRC = RC;
-  }
+const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClass(
+    MCRegister Reg1, MCRegister Reg2, MVT VT) const {
+  return ::getCommonMinimalPhysRegClass(this, Reg1, Reg2, VT);
+}
 
-  return BestRC;
+const TargetRegisterClass *
+TargetRegisterInfo::getMinimalPhysRegClassLLT(MCRegister Reg, LLT Ty) const {
+  return ::getMinimalPhysRegClass(this, Reg, Ty);
 }
 
 const TargetRegisterClass *TargetRegisterInfo::getCommonMinimalPhysRegClassLLT(
     MCRegister Reg1, MCRegister Reg2, LLT Ty) const {
-  assert(Register::isPhysicalRegister(Reg1) &&
-         Register::isPhysicalRegister(Reg2) &&
-         "Reg1/Reg2 must be a physical register");
-
-  // Pick the most sub register class of the right type that contains
-  // this physreg.
-  const TargetRegisterClass *BestRC = nullptr;
-  for (const TargetRegisterClass *RC : regclasses()) {
-    if ((!Ty.isValid() || isTypeLegalForClass(*RC, Ty)) &&
-        RC->contains(Reg1, Reg2) && (!BestRC || BestRC->hasSubClass(RC)))
-      BestRC = RC;
-  }
-
-  return BestRC;
+  return ::getCommonMinimalPhysRegClass(this, Reg1, Reg2, Ty);
 }
 
 /// getAllocatableSetForRC - Toggle the bits that represent allocatable



More information about the llvm-commits mailing list