[llvm] 73ddb2a - [RISCV] Store VLMul/NF into RegisterClass's TSFlags

via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 7 22:35:40 PDT 2024


Author: Pengcheng Wang
Date: 2024-04-08T13:35:37+08:00
New Revision: 73ddb2a7471986a7ed600dbea14efc60f0d0db47

URL: https://github.com/llvm/llvm-project/commit/73ddb2a7471986a7ed600dbea14efc60f0d0db47
DIFF: https://github.com/llvm/llvm-project/commit/73ddb2a7471986a7ed600dbea14efc60f0d0db47.diff

LOG: [RISCV] Store VLMul/NF into RegisterClass's TSFlags

The register-class TSFlags field was introduced by https://reviews.llvm.org/D108767.

A base class for all RISCV register classes is added. IsVRegClass/VLMul/NF
are stored in its TSFlags, and helpers are added to read them back.

This removes some repetitive code, and I expect there will be more uses for it.

Reviewers: preames, topperc

Reviewed By: topperc

Pull Request: https://github.com/llvm/llvm-project/pull/84894

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.h
    llvm/lib/Target/RISCV/RISCVRegisterInfo.h
    llvm/lib/Target/RISCV/RISCVRegisterInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index a1befaf40d09f7..153f936326a78d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -295,12 +295,13 @@ static bool isConvertibleToVMV_V_V(const RISCVSubtarget &STI,
   return false;
 }
 
-void RISCVInstrInfo::copyPhysRegVector(MachineBasicBlock &MBB,
-                                       MachineBasicBlock::iterator MBBI,
-                                       const DebugLoc &DL, MCRegister DstReg,
-                                       MCRegister SrcReg, bool KillSrc,
-                                       RISCVII::VLMUL LMul, unsigned NF) const {
+void RISCVInstrInfo::copyPhysRegVector(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+    const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+    const TargetRegisterClass *RegClass) const {
   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+  RISCVII::VLMUL LMul = RISCVRI::getLMul(RegClass->TSFlags);
+  unsigned NF = RISCVRI::getNF(RegClass->TSFlags);
 
   uint16_t SrcEncoding = TRI->getEncodingValue(SrcReg);
   uint16_t DstEncoding = TRI->getEncodingValue(DstReg);
@@ -522,90 +523,17 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   }
 
   // VR->VR copies.
-  if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1);
-    return;
-  }
-
-  if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_2);
-    return;
-  }
-
-  if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_4);
-    return;
-  }
-
-  if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_8);
-    return;
-  }
-
-  if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/2);
-    return;
-  }
-
-  if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_2,
-                      /*NF=*/2);
-    return;
-  }
-
-  if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_4,
-                      /*NF=*/2);
-    return;
-  }
-
-  if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/3);
-    return;
-  }
-
-  if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_2,
-                      /*NF=*/3);
-    return;
-  }
-
-  if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/4);
-    return;
-  }
-
-  if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_2,
-                      /*NF=*/4);
-    return;
-  }
-
-  if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/5);
-    return;
-  }
-
-  if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/6);
-    return;
-  }
-
-  if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/7);
-    return;
-  }
-
-  if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
-    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCVII::LMUL_1,
-                      /*NF=*/8);
-    return;
+  static const TargetRegisterClass *RVVRegClasses[] = {
+      &RISCV::VRRegClass,     &RISCV::VRM2RegClass,   &RISCV::VRM4RegClass,
+      &RISCV::VRM8RegClass,   &RISCV::VRN2M1RegClass, &RISCV::VRN2M2RegClass,
+      &RISCV::VRN2M4RegClass, &RISCV::VRN3M1RegClass, &RISCV::VRN3M2RegClass,
+      &RISCV::VRN4M1RegClass, &RISCV::VRN4M2RegClass, &RISCV::VRN5M1RegClass,
+      &RISCV::VRN6M1RegClass, &RISCV::VRN7M1RegClass, &RISCV::VRN8M1RegClass};
+  for (const auto &RegClass : RVVRegClasses) {
+    if (RegClass->contains(DstReg, SrcReg)) {
+      copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RegClass);
+      return;
+    }
   }
 
   llvm_unreachable("Impossible reg-to-reg copy");

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index dd049fca059719..3470012d1518ea 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -69,7 +69,7 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
   void copyPhysRegVector(MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, const DebugLoc &DL,
                          MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
-                         RISCVII::VLMUL LMul, unsigned NF = 1) const;
+                         const TargetRegisterClass *RegClass) const;
   void copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                    const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg,
                    bool KillSrc) const override;

diff  --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h
index 943c4f2627cf2f..7e04e9154b524e 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h
@@ -14,12 +14,45 @@
 #define LLVM_LIB_TARGET_RISCV_RISCVREGISTERINFO_H
 
 #include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
 
 #define GET_REGINFO_HEADER
 #include "RISCVGenRegisterInfo.inc"
 
 namespace llvm {
 
+namespace RISCVRI {
+enum {
+  // The IsVRegClass value of this RegisterClass.
+  IsVRegClassShift = 0,
+  IsVRegClassShiftMask = 0b1 << IsVRegClassShift,
+  // The VLMul value of this RegisterClass. This value is valid iff IsVRegClass
+  // is true.
+  VLMulShift = IsVRegClassShift + 1,
+  VLMulShiftMask = 0b111 << VLMulShift,
+
+  // The NF value of this RegisterClass. This value is valid iff IsVRegClass is
+  // true.
+  NFShift = VLMulShift + 3,
+  NFShiftMask = 0b111 << NFShift,
+};
+
+/// \returns true if the IsVRegClass bit of \p TSFlags is set.
+static inline bool isVRegClass(uint64_t TSFlags) {
+  return (TSFlags & IsVRegClassShiftMask) >> IsVRegClassShift;
+}
+
+/// \returns the LMUL for the register class.
+static inline RISCVII::VLMUL getLMul(uint64_t TSFlags) {
+  return static_cast<RISCVII::VLMUL>((TSFlags & VLMulShiftMask) >> VLMulShift);
+}
+
+/// \returns the NF for the register class.
+static inline unsigned getNF(uint64_t TSFlags) {
+  return static_cast<unsigned>((TSFlags & NFShiftMask) >> NFShift) + 1;
+}
+} // namespace RISCVRI
+
 struct RISCVRegisterInfo : public RISCVGenRegisterInfo {
 
   RISCVRegisterInfo(unsigned HwMode);
@@ -116,30 +149,18 @@ struct RISCVRegisterInfo : public RISCVGenRegisterInfo {
   }
 
   static bool isVRRegClass(const TargetRegisterClass *RC) {
-    return RISCV::VRRegClass.hasSubClassEq(RC) ||
-           RISCV::VRM2RegClass.hasSubClassEq(RC) ||
-           RISCV::VRM4RegClass.hasSubClassEq(RC) ||
-           RISCV::VRM8RegClass.hasSubClassEq(RC);
+    return RISCVRI::isVRegClass(RC->TSFlags) &&
+           RISCVRI::getNF(RC->TSFlags) == 1;
   }
 
   static bool isVRNRegClass(const TargetRegisterClass *RC) {
-    return RISCV::VRN2M1RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN2M2RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN2M4RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN3M1RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN3M2RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN4M1RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN4M2RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN5M1RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN6M1RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN7M1RegClass.hasSubClassEq(RC) ||
-           RISCV::VRN8M1RegClass.hasSubClassEq(RC);
+    return RISCVRI::isVRegClass(RC->TSFlags) && RISCVRI::getNF(RC->TSFlags) > 1;
   }
 
   static bool isRVVRegClass(const TargetRegisterClass *RC) {
-    return isVRRegClass(RC) || isVRNRegClass(RC);
+    return RISCVRI::isVRegClass(RC->TSFlags);
   }
 };
-}
+} // namespace llvm
 
 #endif

diff  --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index 90c4a7193ee337..316daf2763ca1e 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -132,8 +132,21 @@ def XLenRI : RegInfoByHwMode<
       [RV32,              RV64],
       [RegInfo<32,32,32>, RegInfo<64,64,64>]>;
 
+class RISCVRegisterClass<list<ValueType> regTypes, int align, dag regList>
+    : RegisterClass<"RISCV", regTypes, align, regList> {
+  bit IsVRegClass = 0;
+  int VLMul = 1;
+  int NF = 1;
+
+  let Size = !if(IsVRegClass, !mul(VLMul, NF, 64), 0);
+
+  let TSFlags{0} = IsVRegClass;
+  let TSFlags{3-1} = !logtwo(VLMul);
+  let TSFlags{6-4} = !sub(NF, 1);
+}
+
 class GPRRegisterClass<dag regList>
-    : RegisterClass<"RISCV", [XLenVT, XLenFVT, i32], 32, regList> {
+    : RISCVRegisterClass<[XLenVT, XLenFVT, i32], 32, regList> {
   let RegInfos = XLenRI;
 }
 
@@ -229,7 +242,7 @@ let RegAltNameIndices = [ABIRegAltName] in {
 // meaning caller-save regs are listed before callee-save.
 // We start by allocating argument registers in reverse order since they are
 // compressible.
-def FPR16 : RegisterClass<"RISCV", [f16, bf16], 16, (add
+def FPR16 : RISCVRegisterClass<[f16, bf16], 16, (add
     (sequence "F%u_H", 15, 10), // fa5-fa0
     (sequence "F%u_H", 0, 7),   // ft0-f7
     (sequence "F%u_H", 16, 17), // fa6-fa7
@@ -238,7 +251,7 @@ def FPR16 : RegisterClass<"RISCV", [f16, bf16], 16, (add
     (sequence "F%u_H", 18, 27)  // fs2-fs11
 )>;
 
-def FPR32 : RegisterClass<"RISCV", [f32], 32, (add
+def FPR32 : RISCVRegisterClass<[f32], 32, (add
     (sequence "F%u_F", 15, 10),
     (sequence "F%u_F", 0, 7),
     (sequence "F%u_F", 16, 17),
@@ -247,14 +260,14 @@ def FPR32 : RegisterClass<"RISCV", [f32], 32, (add
     (sequence "F%u_F", 18, 27)
 )>;
 
-def FPR32C : RegisterClass<"RISCV", [f32], 32, (add
+def FPR32C : RISCVRegisterClass<[f32], 32, (add
   (sequence "F%u_F", 15, 10),
   (sequence "F%u_F", 8, 9)
 )>;
 
 // The order of registers represents the preferred allocation sequence,
 // meaning caller-save regs are listed before callee-save.
-def FPR64 : RegisterClass<"RISCV", [f64], 64, (add
+def FPR64 : RISCVRegisterClass<[f64], 64, (add
     (sequence "F%u_D", 15, 10),
     (sequence "F%u_D", 0, 7),
     (sequence "F%u_D", 16, 17),
@@ -263,7 +276,7 @@ def FPR64 : RegisterClass<"RISCV", [f64], 64, (add
     (sequence "F%u_D", 18, 27)
 )>;
 
-def FPR64C : RegisterClass<"RISCV", [f64], 64, (add
+def FPR64C : RISCVRegisterClass<[f64], 64, (add
   (sequence "F%u_D", 15, 10),
   (sequence "F%u_D", 8, 9)
 )>;
@@ -464,8 +477,8 @@ let isConstant = true in
 def VLENB  : RISCVReg<0, "vlenb">,
              DwarfRegNum<[!add(4096, SysRegVLENB.Encoding)]>;
 
-def VCSR : RegisterClass<"RISCV", [XLenVT], 32,
-                          (add VTYPE, VL, VLENB)> {
+def VCSR : RISCVRegisterClass<[XLenVT], 32,
+                              (add VTYPE, VL, VLENB)> {
   let RegInfos = XLenRI;
   let isAllocatable = 0;
 }
@@ -483,12 +496,11 @@ foreach m = [1, 2, 4] in {
 }
 
 class VReg<list<ValueType> regTypes, dag regList, int Vlmul>
-  : RegisterClass<"RISCV",
-                  regTypes,
-                  64, // The maximum supported ELEN is 64.
-                  regList> {
-  int VLMul = Vlmul;
-  int Size = !mul(Vlmul, 64);
+    : RISCVRegisterClass<regTypes,
+                         64, // The maximum supported ELEN is 64.
+                         regList> {
+  let IsVRegClass = 1;
+  let VLMul = Vlmul;
 }
 
 defvar VMaskVTs = [vbool1_t, vbool2_t, vbool4_t, vbool8_t, vbool16_t,
@@ -537,13 +549,11 @@ def VRM8 : VReg<VM8VTs, (add V8M8, V16M8, V24M8, V0M8), 8>;
 
 def VRM8NoV0 : VReg<VM8VTs, (sub VRM8, V0M8), 8>;
 
-def VMV0 : RegisterClass<"RISCV", VMaskVTs, 64, (add V0)> {
-  let Size = 64;
-}
+def VMV0 : VReg<VMaskVTs, (add V0), 1>;
 
 let RegInfos = XLenRI in {
-def GPRF16  : RegisterClass<"RISCV", [f16], 16, (add GPR)>;
-def GPRF32  : RegisterClass<"RISCV", [f32], 32, (add GPR)>;
+def GPRF16  : RISCVRegisterClass<[f16], 16, (add GPR)>;
+def GPRF32  : RISCVRegisterClass<[f32], 32, (add GPR)>;
 } // RegInfos = XLenRI
 
 // Dummy zero register for use in the register pair containing X0 (as X1 is
@@ -580,7 +590,7 @@ let RegAltNameIndices = [ABIRegAltName] in {
 let RegInfos = RegInfoByHwMode<[RV32, RV64],
                                [RegInfo<64, 64, 32>, RegInfo<128, 128, 64>]>,
     DecoderMethod = "DecodeGPRPairRegisterClass" in
-def GPRPair : RegisterClass<"RISCV", [XLenPairFVT], 64, (add
+def GPRPair : RISCVRegisterClass<[XLenPairFVT], 64, (add
     X10_X11, X12_X13, X14_X15, X16_X17,
     X6_X7,
     X28_X29, X30_X31,
@@ -594,13 +604,17 @@ def VM : VReg<VMaskVTs, (add VR), 1>;
 
 foreach m = LMULList in {
   foreach nf = NFList<m>.L in {
-    def "VRN" # nf # "M" # m # "NoV0": VReg<[untyped],
-                               (add !cast<RegisterTuples>("VN" # nf # "M" # m # "NoV0")),
-                                    !mul(nf, m)>;
-    def "VRN" # nf # "M" # m: VReg<[untyped],
-                               (add !cast<RegisterTuples>("VN" # nf # "M" # m # "NoV0"),
-                                    !cast<RegisterTuples>("VN" # nf # "M" # m # "V0")),
-                                    !mul(nf, m)>;
+    let NF = nf in {
+      def "VRN" # nf # "M" # m # "NoV0"
+        : VReg<[untyped],
+               (add !cast<RegisterTuples>("VN" # nf # "M" # m # "NoV0")),
+               m>;
+      def "VRN" # nf # "M" # m
+        : VReg<[untyped],
+               (add !cast<RegisterTuples>("VN" # nf # "M" # m # "NoV0"),
+                    !cast<RegisterTuples>("VN" # nf # "M" # m # "V0")),
+               m>;
+    }
   }
 }
 


        


More information about the llvm-commits mailing list