[llvm] [RISCV] Add copyPhysRegVector to extract common vector code out of copyPhysReg. (PR #70497)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 12:29:11 PDT 2023


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/70497

>From 550e370202572e01a9419a95197ff0b4003e2940 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 27 Oct 2023 12:06:31 -0700
Subject: [PATCH] [RISCV] Add copyPhysRegVector to extract common vector code
 out of copyPhysReg.

Call this method directly from each vector case with the correct
arguments. This allows us to treat each type of copy as its own
special case and not pass variables to a common merge point. This
is similar to how AArch64 is structured.

I think I can reduce the number of operands to this new method, but
I'll do that as a follow-up to make this patch easier to review.
---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 316 ++++++++++++-----------
 llvm/lib/Target/RISCV/RISCVInstrInfo.h   |   6 +
 2 files changed, 176 insertions(+), 146 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 2f3d2084be70304..9e4e86100a2115b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -294,153 +294,12 @@ static bool isConvertibleToVMV_V_V(const RISCVSubtarget &STI,
   return false;
 }
 
-void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
-                                 MachineBasicBlock::iterator MBBI,
-                                 const DebugLoc &DL, MCRegister DstReg,
-                                 MCRegister SrcReg, bool KillSrc) const {
+void RISCVInstrInfo::copyPhysRegVector(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+    const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+    unsigned Opc, unsigned NF, RISCVII::VLMUL LMul, unsigned SubRegIdx) const {
   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
 
-  if (RISCV::GPRRegClass.contains(DstReg, SrcReg)) {
-    BuildMI(MBB, MBBI, DL, get(RISCV::ADDI), DstReg)
-        .addReg(SrcReg, getKillRegState(KillSrc))
-        .addImm(0);
-    return;
-  }
-
-  if (RISCV::GPRPF64RegClass.contains(DstReg, SrcReg)) {
-    // Emit an ADDI for both parts of GPRPF64.
-    BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
-            TRI->getSubReg(DstReg, RISCV::sub_32))
-        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_32), getKillRegState(KillSrc))
-        .addImm(0);
-    BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
-            TRI->getSubReg(DstReg, RISCV::sub_32_hi))
-        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_32_hi),
-                getKillRegState(KillSrc))
-        .addImm(0);
-    return;
-  }
-
-  // Handle copy from csr
-  if (RISCV::VCSRRegClass.contains(SrcReg) &&
-      RISCV::GPRRegClass.contains(DstReg)) {
-    BuildMI(MBB, MBBI, DL, get(RISCV::CSRRS), DstReg)
-        .addImm(RISCVSysReg::lookupSysRegByName(TRI->getName(SrcReg))->Encoding)
-        .addReg(RISCV::X0);
-    return;
-  }
-
-  if (RISCV::FPR16RegClass.contains(DstReg, SrcReg)) {
-    unsigned Opc;
-    if (STI.hasStdExtZfh()) {
-      Opc = RISCV::FSGNJ_H;
-    } else {
-      assert(STI.hasStdExtF() &&
-             (STI.hasStdExtZfhmin() || STI.hasStdExtZfbfmin()) &&
-             "Unexpected extensions");
-      // Zfhmin/Zfbfmin doesn't have FSGNJ_H, replace FSGNJ_H with FSGNJ_S.
-      DstReg = TRI->getMatchingSuperReg(DstReg, RISCV::sub_16,
-                                        &RISCV::FPR32RegClass);
-      SrcReg = TRI->getMatchingSuperReg(SrcReg, RISCV::sub_16,
-                                        &RISCV::FPR32RegClass);
-      Opc = RISCV::FSGNJ_S;
-    }
-    BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
-        .addReg(SrcReg, getKillRegState(KillSrc))
-        .addReg(SrcReg, getKillRegState(KillSrc));
-    return;
-  }
-
-  if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
-    BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_S), DstReg)
-        .addReg(SrcReg, getKillRegState(KillSrc))
-        .addReg(SrcReg, getKillRegState(KillSrc));
-    return;
-  }
-
-  if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
-    BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_D), DstReg)
-        .addReg(SrcReg, getKillRegState(KillSrc))
-        .addReg(SrcReg, getKillRegState(KillSrc));
-    return;
-  }
-
-  // VR->VR copies.
-  unsigned Opc;
-  unsigned NF = 1;
-  RISCVII::VLMUL LMul = RISCVII::LMUL_1;
-  unsigned SubRegIdx = RISCV::sub_vrm1_0;
-  if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV4R_V;
-    LMul = RISCVII::LMUL_4;
-  } else if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV8R_V;
-    LMul = RISCVII::LMUL_8;
-  } else if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 2;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    SubRegIdx = RISCV::sub_vrm2_0;
-    NF = 2;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV4R_V;
-    SubRegIdx = RISCV::sub_vrm4_0;
-    NF = 2;
-    LMul = RISCVII::LMUL_4;
-  } else if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 3;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    SubRegIdx = RISCV::sub_vrm2_0;
-    NF = 3;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 4;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    SubRegIdx = RISCV::sub_vrm2_0;
-    NF = 4;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 5;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 6;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 7;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 8;
-    LMul = RISCVII::LMUL_1;
-  } else {
-    llvm_unreachable("Impossible reg-to-reg copy");
-  }
-
   bool UseVMV_V_V = false;
   bool UseVMV_V_I = false;
   MachineBasicBlock::const_iterator DefMBBI;
@@ -518,7 +377,7 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                          getKillRegState(KillSrc));
       if (UseVMV_V_V) {
         const MCInstrDesc &Desc = DefMBBI->getDesc();
-        MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
+        MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc)));  // AVL
         MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
         MIB.addImm(0);                                            // tu, mu
         MIB.addReg(RISCV::VL, RegState::Implicit);
@@ -528,6 +387,171 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   }
 }
 
+void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
+                                 MachineBasicBlock::iterator MBBI,
+                                 const DebugLoc &DL, MCRegister DstReg,
+                                 MCRegister SrcReg, bool KillSrc) const {
+  const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+
+  if (RISCV::GPRRegClass.contains(DstReg, SrcReg)) {
+    BuildMI(MBB, MBBI, DL, get(RISCV::ADDI), DstReg)
+        .addReg(SrcReg, getKillRegState(KillSrc))
+        .addImm(0);
+    return;
+  }
+
+  if (RISCV::GPRPF64RegClass.contains(DstReg, SrcReg)) {
+    // Emit an ADDI for both parts of GPRPF64.
+    BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
+            TRI->getSubReg(DstReg, RISCV::sub_32))
+        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_32), getKillRegState(KillSrc))
+        .addImm(0);
+    BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
+            TRI->getSubReg(DstReg, RISCV::sub_32_hi))
+        .addReg(TRI->getSubReg(SrcReg, RISCV::sub_32_hi),
+                getKillRegState(KillSrc))
+        .addImm(0);
+    return;
+  }
+
+  // Handle copy from csr
+  if (RISCV::VCSRRegClass.contains(SrcReg) &&
+      RISCV::GPRRegClass.contains(DstReg)) {
+    BuildMI(MBB, MBBI, DL, get(RISCV::CSRRS), DstReg)
+        .addImm(RISCVSysReg::lookupSysRegByName(TRI->getName(SrcReg))->Encoding)
+        .addReg(RISCV::X0);
+    return;
+  }
+
+  if (RISCV::FPR16RegClass.contains(DstReg, SrcReg)) {
+    unsigned Opc;
+    if (STI.hasStdExtZfh()) {
+      Opc = RISCV::FSGNJ_H;
+    } else {
+      assert(STI.hasStdExtF() &&
+             (STI.hasStdExtZfhmin() || STI.hasStdExtZfbfmin()) &&
+             "Unexpected extensions");
+      // Zfhmin/Zfbfmin doesn't have FSGNJ_H, replace FSGNJ_H with FSGNJ_S.
+      DstReg = TRI->getMatchingSuperReg(DstReg, RISCV::sub_16,
+                                        &RISCV::FPR32RegClass);
+      SrcReg = TRI->getMatchingSuperReg(SrcReg, RISCV::sub_16,
+                                        &RISCV::FPR32RegClass);
+      Opc = RISCV::FSGNJ_S;
+    }
+    BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
+        .addReg(SrcReg, getKillRegState(KillSrc))
+        .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
+  }
+
+  if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
+    BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_S), DstReg)
+        .addReg(SrcReg, getKillRegState(KillSrc))
+        .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
+  }
+
+  if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
+    BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_D), DstReg)
+        .addReg(SrcReg, getKillRegState(KillSrc))
+        .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
+  }
+
+  // VR->VR copies.
+  if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/1, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/1, RISCVII::LMUL_2, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV4R_V,
+                      /*NF=*/1, RISCVII::LMUL_4, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV8R_V,
+                      /*NF=*/1, RISCVII::LMUL_8, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/2, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/2, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+    return;
+  }
+
+  if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV4R_V,
+                      /*NF=*/2, RISCVII::LMUL_4, RISCV::sub_vrm4_0);
+    return;
+  }
+
+  if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/3, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/3, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+    return;
+  }
+
+  if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/4, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/4, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+    return;
+  }
+
+  if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/5, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/6, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/7, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/8, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  llvm_unreachable("Impossible reg-to-reg copy");
+}
+
 void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
                                          MachineBasicBlock::iterator I,
                                          Register SrcReg, bool IsKill, int FI,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 5584e5571c9bc35..13f1bd4127b12b7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_LIB_TARGET_RISCV_RISCVINSTRINFO_H
 #define LLVM_LIB_TARGET_RISCV_RISCVINSTRINFO_H
 
+#include "MCTargetDesc/RISCVBaseInfo.h"
 #include "RISCVRegisterInfo.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -63,6 +64,11 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
   unsigned isStoreToStackSlot(const MachineInstr &MI, int &FrameIndex,
                               unsigned &MemBytes) const override;
 
+  void copyPhysRegVector(MachineBasicBlock &MBB,
+                         MachineBasicBlock::iterator MBBI, const DebugLoc &DL,
+                         MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+                         unsigned Opc, unsigned NF, RISCVII::VLMUL LMul,
+                         unsigned SubRegIdx) const;
   void copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                    const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg,
                    bool KillSrc) const override;



More information about the llvm-commits mailing list