[llvm] [RISCV] Refactor GPRF64 register class to make it usable for Zacas. (PR #77408)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 9 00:03:54 PST 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/77408

>From a69fbf642552ed3ad6dd0a726685f165c11f4a3d Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 8 Jan 2024 19:49:11 -0800
Subject: [PATCH 1/3] [RISCV] Refactor GPRF64 register class to make it usable
 for Zacas.

- Rename to GPRPair.
- Rename registers to be named like X10_X11 instead of X10_PD. The exception
  is X0, which is now named X0_Pair since it is not paired with X1.
- Use unknown size and offset for the subreg indices. This might
  be a functional change, but it does not affect any lit tests.
---
 .../Target/RISCV/AsmParser/RISCVAsmParser.cpp |  2 +-
 .../RISCV/Disassembler/RISCVDisassembler.cpp  |  2 +-
 .../Target/RISCV/RISCVExpandPseudoInsts.cpp   | 12 ++++--
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 10 ++---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      | 17 ++++----
 llvm/lib/Target/RISCV/RISCVInstrInfoD.td      | 17 ++++----
 llvm/lib/Target/RISCV/RISCVRegisterInfo.td    | 42 +++++++++++--------
 7 files changed, 57 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index d616aaeddf4114..4250950a917299 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -1295,7 +1295,7 @@ unsigned RISCVAsmParser::checkTargetMatchPredicate(MCInst &Inst) {
   const MCInstrDesc &MCID = MII.get(Inst.getOpcode());
 
   for (unsigned I = 0; I < MCID.NumOperands; ++I) {
-    if (MCID.operands()[I].RegClass == RISCV::GPRPF64RegClassID) {
+    if (MCID.operands()[I].RegClass == RISCV::GPRPairRegClassID) {
       const auto &Op = Inst.getOperand(I);
       assert(Op.isReg());
 
diff --git a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp
index ed80da14c79574..bc65cf2403b262 100644
--- a/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp
+++ b/llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp
@@ -171,7 +171,7 @@ static DecodeStatus DecodeGPRCRegisterClass(MCInst &Inst, uint32_t RegNo,
   return MCDisassembler::Success;
 }
 
-static DecodeStatus DecodeGPRPF64RegisterClass(MCInst &Inst, uint32_t RegNo,
+static DecodeStatus DecodeGPRPairRegisterClass(MCInst &Inst, uint32_t RegNo,
                                                uint64_t Address,
                                                const MCDisassembler *Decoder) {
   if (RegNo >= 32 || RegNo & 1)
diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
index 24a13f93af880e..7592392de53b9c 100644
--- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
@@ -300,8 +300,10 @@ bool RISCVExpandPseudo::expandRV32ZdinxStore(MachineBasicBlock &MBB,
                                              MachineBasicBlock::iterator MBBI) {
   DebugLoc DL = MBBI->getDebugLoc();
   const TargetRegisterInfo *TRI = STI->getRegisterInfo();
-  Register Lo = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32);
-  Register Hi = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32_hi);
+  Register Lo =
+      TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_even);
+  Register Hi =
+      TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_odd);
   BuildMI(MBB, MBBI, DL, TII->get(RISCV::SW))
       .addReg(Lo, getKillRegState(MBBI->getOperand(0).isKill()))
       .addReg(MBBI->getOperand(1).getReg())
@@ -334,8 +336,10 @@ bool RISCVExpandPseudo::expandRV32ZdinxLoad(MachineBasicBlock &MBB,
                                             MachineBasicBlock::iterator MBBI) {
   DebugLoc DL = MBBI->getDebugLoc();
   const TargetRegisterInfo *TRI = STI->getRegisterInfo();
-  Register Lo = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32);
-  Register Hi = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32_hi);
+  Register Lo =
+      TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_even);
+  Register Hi =
+      TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_odd);
 
   // If the register of operand 1 is equal to the Lo register, then swap the
   // order of loading the Lo and Hi statements.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 79c16cf4c4c361..abadd08a4d90d9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -138,7 +138,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     if (Subtarget.is64Bit())
       addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
     else
-      addRegisterClass(MVT::f64, &RISCV::GPRPF64RegClass);
+      addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
   }
 
   static const MVT::SimpleValueType BoolVecVTs[] = {
@@ -16345,7 +16345,7 @@ static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
   Register SrcReg = MI.getOperand(2).getReg();
 
   const TargetRegisterClass *SrcRC = MI.getOpcode() == RISCV::SplitF64Pseudo_INX
-                                         ? &RISCV::GPRPF64RegClass
+                                         ? &RISCV::GPRPairRegClass
                                          : &RISCV::FPR64RegClass;
   int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
 
@@ -16384,7 +16384,7 @@ static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI,
   Register HiReg = MI.getOperand(2).getReg();
 
   const TargetRegisterClass *DstRC =
-      MI.getOpcode() == RISCV::BuildPairF64Pseudo_INX ? &RISCV::GPRPF64RegClass
+      MI.getOpcode() == RISCV::BuildPairF64Pseudo_INX ? &RISCV::GPRPairRegClass
                                                       : &RISCV::FPR64RegClass;
   int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);
 
@@ -18751,7 +18751,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
         return std::make_pair(0U, &RISCV::GPRF32RegClass);
       if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
-        return std::make_pair(0U, &RISCV::GPRPF64RegClass);
+        return std::make_pair(0U, &RISCV::GPRPairRegClass);
       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
     case 'f':
       if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
@@ -18933,7 +18933,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   // Subtarget into account.
   if (Res.second == &RISCV::GPRF16RegClass ||
       Res.second == &RISCV::GPRF32RegClass ||
-      Res.second == &RISCV::GPRPF64RegClass)
+      Res.second == &RISCV::GPRPairRegClass)
     return std::make_pair(Res.first, &RISCV::GPRRegClass);
 
   return Res;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 7f6a045a7d042f..388e0db27046e2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -414,15 +414,16 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }
 
-  if (RISCV::GPRPF64RegClass.contains(DstReg, SrcReg)) {
-    // Emit an ADDI for both parts of GPRPF64.
+  if (RISCV::GPRPairRegClass.contains(DstReg, SrcReg)) {
+    // Emit an ADDI for both parts of GPRPair.
     BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
-            TRI->getSubReg(DstReg, RISCV::sub_32))
-        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_32), getKillRegState(KillSrc))
+            TRI->getSubReg(DstReg, RISCV::sub_gpr_even))
+        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_gpr_even),
+                getKillRegState(KillSrc))
         .addImm(0);
     BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
-            TRI->getSubReg(DstReg, RISCV::sub_32_hi))
-        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_32_hi),
+            TRI->getSubReg(DstReg, RISCV::sub_gpr_odd))
+        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_gpr_odd),
                 getKillRegState(KillSrc))
         .addImm(0);
     return;
@@ -607,7 +608,7 @@ void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
     Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
              RISCV::SW : RISCV::SD;
     IsScalableVector = false;
-  } else if (RISCV::GPRPF64RegClass.hasSubClassEq(RC)) {
+  } else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
     Opcode = RISCV::PseudoRV32ZdinxSD;
     IsScalableVector = false;
   } else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
@@ -690,7 +691,7 @@ void RISCVInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
     Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
              RISCV::LW : RISCV::LD;
     IsScalableVector = false;
-  } else if (RISCV::GPRPF64RegClass.hasSubClassEq(RC)) {
+  } else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
     Opcode = RISCV::PseudoRV32ZdinxLD;
     IsScalableVector = false;
   } else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
index 418421b2a556f7..6e89934b4052c2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
@@ -33,8 +33,8 @@ def AddrRegImmINX : ComplexPattern<iPTR, 2, "SelectAddrRegImmINX">;
 
 // Zdinx
 
-def GPRPF64AsFPR : AsmOperandClass {
-  let Name = "GPRPF64AsFPR";
+def GPRPairAsFPR : AsmOperandClass {
+  let Name = "GPRPairAsFPR";
   let ParserMethod = "parseGPRAsFPR";
   let PredicateMethod = "isGPRAsFPR";
   let RenderMethod = "addRegOperands";
@@ -52,8 +52,9 @@ def FPR64INX : RegisterOperand<GPR> {
   let DecoderMethod = "DecodeGPRRegisterClass";
 }
 
-def FPR64IN32X : RegisterOperand<GPRPF64> {
-  let ParserMatchClass = GPRPF64AsFPR;
+def FPR64IN32X : RegisterOperand<GPRPair> {
+  let ParserMatchClass = GPRPairAsFPR;
+  let DecoderMethod = "DecodeGPRPairRegisterClass";
 }
 
 def DExt       : ExtInfo<"", "", [HasStdExtD], f64, FPR64, FPR32, FPR64, ?>;
@@ -515,15 +516,15 @@ def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>;
 
 /// Loads
 let isCall = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in
-def PseudoRV32ZdinxLD : Pseudo<(outs GPRPF64:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
+def PseudoRV32ZdinxLD : Pseudo<(outs GPRPair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
 def : Pat<(f64 (load (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12))),
           (PseudoRV32ZdinxLD GPR:$rs1, simm12:$imm12)>;
 
 /// Stores
 let isCall = 0, mayLoad = 0, mayStore = 1, Size = 8, isCodeGenOnly = 1 in
-def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPF64:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
-def : Pat<(store (f64 GPRPF64:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
-          (PseudoRV32ZdinxSD GPRPF64:$rs2, GPR:$rs1, simm12:$imm12)>;
+def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
+def : Pat<(store (f64 GPRPair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
+          (PseudoRV32ZdinxSD GPRPair:$rs2, GPR:$rs1, simm12:$imm12)>;
 
 /// Pseudo-instructions needed for the soft-float ABI with RV32D
 
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index a59d058382fe58..8d19d8e935a929 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -63,7 +63,10 @@ def sub_vrm1_5 : ComposedSubRegIndex<sub_vrm2_2, sub_vrm1_1>;
 def sub_vrm1_6 : ComposedSubRegIndex<sub_vrm2_3, sub_vrm1_0>;
 def sub_vrm1_7 : ComposedSubRegIndex<sub_vrm2_3, sub_vrm1_1>;
 
-def sub_32_hi  : SubRegIndex<32, 32>;
+// GPR sizes change with HwMode.
+// FIXME: Support HwMode in SubRegIndex?
+def sub_gpr_even : SubRegIndex<-1>;
+def sub_gpr_odd  : SubRegIndex<-1, -1>;
 } // Namespace = "RISCV"
 
 // Integer registers
@@ -546,33 +549,36 @@ def DUMMY_REG_PAIR_WITH_X0 : RISCVReg<0, "0">;
 def GPRAll : GPRRegisterClass<(add GPR, DUMMY_REG_PAIR_WITH_X0)>;
 
 let RegAltNameIndices = [ABIRegAltName] in {
-  def X0_PD : RISCVRegWithSubRegs<0, X0.AsmName,
-                                     [X0, DUMMY_REG_PAIR_WITH_X0],
-                                     X0.AltNames> {
-    let SubRegIndices = [sub_32, sub_32_hi];
+  def X0_Pair : RISCVRegWithSubRegs<0, X0.AsmName,
+                                    [X0, DUMMY_REG_PAIR_WITH_X0],
+                                    X0.AltNames> {
+    let SubRegIndices = [sub_gpr_even, sub_gpr_odd];
     let CoveredBySubRegs = 1;
   }
   foreach I = 1-15 in {
     defvar Index = !shl(I, 1);
+    defvar IndexP1 = !add(Index, 1);
     defvar Reg = !cast<Register>("X"#Index);
-    defvar RegP1 = !cast<Register>("X"#!add(Index,1));
-    def X#Index#_PD : RISCVRegWithSubRegs<Index, Reg.AsmName,
-                                          [Reg, RegP1],
-                                          Reg.AltNames> {
-      let SubRegIndices = [sub_32, sub_32_hi];
+    defvar RegP1 = !cast<Register>("X"#IndexP1);
+    def "X" # Index #"_X" # IndexP1 : RISCVRegWithSubRegs<Index,
+                                                          Reg.AsmName,
+                                                          [Reg, RegP1],
+                                                          Reg.AltNames> {
+      let SubRegIndices = [sub_gpr_even, sub_gpr_odd];
       let CoveredBySubRegs = 1;
     }
   }
 }
 
-let RegInfos = RegInfoByHwMode<[RV64], [RegInfo<64, 64, 64>]> in
-def GPRPF64 : RegisterClass<"RISCV", [f64], 64, (add
-    X10_PD, X12_PD, X14_PD, X16_PD,
-    X6_PD,
-    X28_PD, X30_PD,
-    X8_PD,
-    X18_PD, X20_PD, X22_PD, X24_PD, X26_PD,
-    X0_PD, X2_PD, X4_PD
+let RegInfos = RegInfoByHwMode<[RV64], [RegInfo<64, 64, 64>]>,
+    DecoderMethod = "DecodeGPRPairRegisterClass" in
+def GPRPair : RegisterClass<"RISCV", [f64], 64, (add
+    X10_X11, X12_X13, X14_X15, X16_X17,
+    X6_X7,
+    X28_X29, X30_X31,
+    X8_X9,
+    X18_X19, X20_X21, X22_X23, X24_X25, X26_X27,
+    X0_Pair, X2_X3, X4_X5
 )>;
 
 // The register class is added for inline assembly for vector mask types.

>From 71d7cea73b1e1373a5fb7ec04d2ac07a23bc5be3 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 8 Jan 2024 20:31:08 -0800
Subject: [PATCH 2/3] fixup! remove unneeded change

---
 llvm/lib/Target/RISCV/RISCVInstrInfoD.td | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
index 6e89934b4052c2..fec43d814098ce 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
@@ -54,7 +54,6 @@ def FPR64INX : RegisterOperand<GPR> {
 
 def FPR64IN32X : RegisterOperand<GPRPair> {
   let ParserMatchClass = GPRPairAsFPR;
-  let DecoderMethod = "DecodeGPRPairRegisterClass";
 }
 
 def DExt       : ExtInfo<"", "", [HasStdExtD], f64, FPR64, FPR32, FPR64, ?>;

>From 47a8943cc17c8ff4b2cd40a42ad29c43deb12d10 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 9 Jan 2024 00:03:36 -0800
Subject: [PATCH 3/3] fixup! Update VT and RegInfos.

---
 llvm/lib/Target/RISCV/RISCVRegisterInfo.td | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
index 8d19d8e935a929..5a4d8c4cfece7f 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -121,6 +121,8 @@ def XLenVT : ValueTypeByHwMode<[RV32, RV64],
 // Allow f64 in GPR for ZDINX on RV64.
 def XLenFVT : ValueTypeByHwMode<[RV64],
                                 [f64]>;
+def XLenPairFVT : ValueTypeByHwMode<[RV32],
+                                    [f64]>;
 def XLenRI : RegInfoByHwMode<
       [RV32,              RV64],
       [RegInfo<32,32,32>, RegInfo<64,64,64>]>;
@@ -570,9 +572,10 @@ let RegAltNameIndices = [ABIRegAltName] in {
   }
 }
 
-let RegInfos = RegInfoByHwMode<[RV64], [RegInfo<64, 64, 64>]>,
+let RegInfos = RegInfoByHwMode<[RV32, RV64],
+                               [RegInfo<64, 64, 64>, RegInfo<128, 128, 128>]>,
     DecoderMethod = "DecodeGPRPairRegisterClass" in
-def GPRPair : RegisterClass<"RISCV", [f64], 64, (add
+def GPRPair : RegisterClass<"RISCV", [XLenPairFVT], 64, (add
     X10_X11, X12_X13, X14_X15, X16_X17,
     X6_X7,
     X28_X29, X30_X31,



More information about the llvm-commits mailing list