[llvm] [RISCV] Add copyPhysRegVector to extract common vector code out of copyPhysReg. (PR #70497)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 12:12:17 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

Call this method directly from each vector case with the correct
arguments. This allows us to treat each type of copy as its own
special case and not pass variables to a common merge point. This
is similar to how AArch64 is structured.
    
I think I can reduce the number of operands to this new method, but
I'll do that as a follow-up to make this patch easier to review.

Stacked on #<!-- -->70492

---
Full diff: https://github.com/llvm/llvm-project/pull/70497.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (+200-170) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.h (+6) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index ad31b2974993c74..9e4e86100a2115b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -294,6 +294,99 @@ static bool isConvertibleToVMV_V_V(const RISCVSubtarget &STI,
   return false;
 }
 
+void RISCVInstrInfo::copyPhysRegVector(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+    const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+    unsigned Opc, unsigned NF, RISCVII::VLMUL LMul, unsigned SubRegIdx) const {
+  const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+
+  bool UseVMV_V_V = false;
+  bool UseVMV_V_I = false;
+  MachineBasicBlock::const_iterator DefMBBI;
+  if (isConvertibleToVMV_V_V(STI, MBB, MBBI, DefMBBI, LMul)) {
+    UseVMV_V_V = true;
+    // We only need to handle LMUL = 1/2/4/8 here because we only define
+    // vector register classes for LMUL = 1/2/4/8.
+    unsigned VIOpc;
+    switch (LMul) {
+    default:
+      llvm_unreachable("Impossible LMUL for vector register copy.");
+    case RISCVII::LMUL_1:
+      Opc = RISCV::PseudoVMV_V_V_M1;
+      VIOpc = RISCV::PseudoVMV_V_I_M1;
+      break;
+    case RISCVII::LMUL_2:
+      Opc = RISCV::PseudoVMV_V_V_M2;
+      VIOpc = RISCV::PseudoVMV_V_I_M2;
+      break;
+    case RISCVII::LMUL_4:
+      Opc = RISCV::PseudoVMV_V_V_M4;
+      VIOpc = RISCV::PseudoVMV_V_I_M4;
+      break;
+    case RISCVII::LMUL_8:
+      Opc = RISCV::PseudoVMV_V_V_M8;
+      VIOpc = RISCV::PseudoVMV_V_I_M8;
+      break;
+    }
+
+    if (DefMBBI->getOpcode() == VIOpc) {
+      UseVMV_V_I = true;
+      Opc = VIOpc;
+    }
+  }
+
+  if (NF == 1) {
+    auto MIB = BuildMI(MBB, MBBI, DL, get(Opc), DstReg);
+    if (UseVMV_V_V)
+      MIB.addReg(DstReg, RegState::Undef);
+    if (UseVMV_V_I)
+      MIB = MIB.add(DefMBBI->getOperand(2));
+    else
+      MIB = MIB.addReg(SrcReg, getKillRegState(KillSrc));
+    if (UseVMV_V_V) {
+      const MCInstrDesc &Desc = DefMBBI->getDesc();
+      MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc)));  // AVL
+      MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
+      MIB.addImm(0);                                            // tu, mu
+      MIB.addReg(RISCV::VL, RegState::Implicit);
+      MIB.addReg(RISCV::VTYPE, RegState::Implicit);
+    }
+  } else {
+    int I = 0, End = NF, Incr = 1;
+    unsigned SrcEncoding = TRI->getEncodingValue(SrcReg);
+    unsigned DstEncoding = TRI->getEncodingValue(DstReg);
+    unsigned LMulVal;
+    bool Fractional;
+    std::tie(LMulVal, Fractional) = RISCVVType::decodeVLMUL(LMul);
+    assert(!Fractional && "It is impossible be fractional lmul here.");
+    if (forwardCopyWillClobberTuple(DstEncoding, SrcEncoding, NF * LMulVal)) {
+      I = NF - 1;
+      End = -1;
+      Incr = -1;
+    }
+
+    for (; I != End; I += Incr) {
+      auto MIB = BuildMI(MBB, MBBI, DL, get(Opc),
+                         TRI->getSubReg(DstReg, SubRegIdx + I));
+      if (UseVMV_V_V)
+        MIB.addReg(TRI->getSubReg(DstReg, SubRegIdx + I), RegState::Undef);
+      if (UseVMV_V_I)
+        MIB = MIB.add(DefMBBI->getOperand(2));
+      else
+        MIB = MIB.addReg(TRI->getSubReg(SrcReg, SubRegIdx + I),
+                         getKillRegState(KillSrc));
+      if (UseVMV_V_V) {
+        const MCInstrDesc &Desc = DefMBBI->getDesc();
+        MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc)));  // AVL
+        MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
+        MIB.addImm(0);                                            // tu, mu
+        MIB.addReg(RISCV::VL, RegState::Implicit);
+        MIB.addReg(RISCV::VTYPE, RegState::Implicit);
+      }
+    }
+  }
+}
+
 void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                  MachineBasicBlock::iterator MBBI,
                                  const DebugLoc &DL, MCRegister DstReg,
@@ -330,13 +423,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }
 
-  // FPR->FPR copies and VR->VR copies.
-  unsigned Opc;
-  bool IsScalableVector = true;
-  unsigned NF = 1;
-  RISCVII::VLMUL LMul = RISCVII::LMUL_1;
-  unsigned SubRegIdx = RISCV::sub_vrm1_0;
   if (RISCV::FPR16RegClass.contains(DstReg, SrcReg)) {
+    unsigned Opc;
     if (STI.hasStdExtZfh()) {
       Opc = RISCV::FSGNJ_H;
     } else {
@@ -350,176 +438,118 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                         &RISCV::FPR32RegClass);
       Opc = RISCV::FSGNJ_S;
     }
-    IsScalableVector = false;
-  } else if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::FSGNJ_S;
-    IsScalableVector = false;
-  } else if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::FSGNJ_D;
-    IsScalableVector = false;
-  } else if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV4R_V;
-    LMul = RISCVII::LMUL_4;
-  } else if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV8R_V;
-    LMul = RISCVII::LMUL_8;
-  } else if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 2;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    SubRegIdx = RISCV::sub_vrm2_0;
-    NF = 2;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV4R_V;
-    SubRegIdx = RISCV::sub_vrm4_0;
-    NF = 2;
-    LMul = RISCVII::LMUL_4;
-  } else if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 3;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    SubRegIdx = RISCV::sub_vrm2_0;
-    NF = 3;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 4;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV2R_V;
-    SubRegIdx = RISCV::sub_vrm2_0;
-    NF = 4;
-    LMul = RISCVII::LMUL_2;
-  } else if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 5;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 6;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 7;
-    LMul = RISCVII::LMUL_1;
-  } else if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
-    Opc = RISCV::VMV1R_V;
-    SubRegIdx = RISCV::sub_vrm1_0;
-    NF = 8;
-    LMul = RISCVII::LMUL_1;
-  } else {
-    llvm_unreachable("Impossible reg-to-reg copy");
+    BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
+        .addReg(SrcReg, getKillRegState(KillSrc))
+        .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
   }
 
-  if (IsScalableVector) {
-    bool UseVMV_V_V = false;
-    bool UseVMV_V_I = false;
-    MachineBasicBlock::const_iterator DefMBBI;
-    if (isConvertibleToVMV_V_V(STI, MBB, MBBI, DefMBBI, LMul)) {
-      UseVMV_V_V = true;
-      // We only need to handle LMUL = 1/2/4/8 here because we only define
-      // vector register classes for LMUL = 1/2/4/8.
-      unsigned VIOpc;
-      switch (LMul) {
-      default:
-        llvm_unreachable("Impossible LMUL for vector register copy.");
-      case RISCVII::LMUL_1:
-        Opc = RISCV::PseudoVMV_V_V_M1;
-        VIOpc = RISCV::PseudoVMV_V_I_M1;
-        break;
-      case RISCVII::LMUL_2:
-        Opc = RISCV::PseudoVMV_V_V_M2;
-        VIOpc = RISCV::PseudoVMV_V_I_M2;
-        break;
-      case RISCVII::LMUL_4:
-        Opc = RISCV::PseudoVMV_V_V_M4;
-        VIOpc = RISCV::PseudoVMV_V_I_M4;
-        break;
-      case RISCVII::LMUL_8:
-        Opc = RISCV::PseudoVMV_V_V_M8;
-        VIOpc = RISCV::PseudoVMV_V_I_M8;
-        break;
-      }
-
-      if (DefMBBI->getOpcode() == VIOpc) {
-        UseVMV_V_I = true;
-        Opc = VIOpc;
-      }
-    }
-
-    if (NF == 1) {
-      auto MIB = BuildMI(MBB, MBBI, DL, get(Opc), DstReg);
-      if (UseVMV_V_V)
-        MIB.addReg(DstReg, RegState::Undef);
-      if (UseVMV_V_I)
-        MIB = MIB.add(DefMBBI->getOperand(2));
-      else
-        MIB = MIB.addReg(SrcReg, getKillRegState(KillSrc));
-      if (UseVMV_V_V) {
-        const MCInstrDesc &Desc = DefMBBI->getDesc();
-        MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
-        MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
-        MIB.addImm(0); // tu, mu
-        MIB.addReg(RISCV::VL, RegState::Implicit);
-        MIB.addReg(RISCV::VTYPE, RegState::Implicit);
-      }
-    } else {
-      int I = 0, End = NF, Incr = 1;
-      unsigned SrcEncoding = TRI->getEncodingValue(SrcReg);
-      unsigned DstEncoding = TRI->getEncodingValue(DstReg);
-      unsigned LMulVal;
-      bool Fractional;
-      std::tie(LMulVal, Fractional) = RISCVVType::decodeVLMUL(LMul);
-      assert(!Fractional && "It is impossible be fractional lmul here.");
-      if (forwardCopyWillClobberTuple(DstEncoding, SrcEncoding, NF * LMulVal)) {
-        I = NF - 1;
-        End = -1;
-        Incr = -1;
-      }
+  if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
+    BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_S), DstReg)
+        .addReg(SrcReg, getKillRegState(KillSrc))
+        .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
+  }
 
-      for (; I != End; I += Incr) {
-        auto MIB = BuildMI(MBB, MBBI, DL, get(Opc),
-                           TRI->getSubReg(DstReg, SubRegIdx + I));
-        if (UseVMV_V_V)
-          MIB.addReg(TRI->getSubReg(DstReg, SubRegIdx + I),
-                     RegState::Undef);
-        if (UseVMV_V_I)
-          MIB = MIB.add(DefMBBI->getOperand(2));
-        else
-          MIB = MIB.addReg(TRI->getSubReg(SrcReg, SubRegIdx + I),
-                           getKillRegState(KillSrc));
-        if (UseVMV_V_V) {
-          const MCInstrDesc &Desc = DefMBBI->getDesc();
-          MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
-          MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
-          MIB.addImm(0);  // tu, mu
-          MIB.addReg(RISCV::VL, RegState::Implicit);
-          MIB.addReg(RISCV::VTYPE, RegState::Implicit);
-        }
-      }
-    }
-  } else {
-    BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
+  if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
+    BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_D), DstReg)
         .addReg(SrcReg, getKillRegState(KillSrc))
         .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
+  }
+
+  // VR->VR copies.
+  if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/1, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/1, RISCVII::LMUL_2, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV4R_V,
+                      /*NF=*/1, RISCVII::LMUL_4, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV8R_V,
+                      /*NF=*/1, RISCVII::LMUL_8, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/2, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/2, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+    return;
+  }
+
+  if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV4R_V,
+                      /*NF=*/2, RISCVII::LMUL_4, RISCV::sub_vrm4_0);
+    return;
+  }
+
+  if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/3, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/3, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+    return;
+  }
+
+  if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/4, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+                      /*NF=*/4, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+    return;
+  }
+
+  if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/5, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/6, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
   }
+
+  if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/7, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
+    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+                      /*NF=*/8, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+    return;
+  }
+
+  llvm_unreachable("Impossible reg-to-reg copy");
 }
 
 void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 5584e5571c9bc35..4b93d44ed2d5d85 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -14,6 +14,7 @@
 #define LLVM_LIB_TARGET_RISCV_RISCVINSTRINFO_H
 
 #include "RISCVRegisterInfo.h"
+#include "MCTargetDesc/RISCVBaseInfo.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/IR/DiagnosticInfo.h"
 
@@ -63,6 +64,11 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
   unsigned isStoreToStackSlot(const MachineInstr &MI, int &FrameIndex,
                               unsigned &MemBytes) const override;
 
+  void copyPhysRegVector(MachineBasicBlock &MBB,
+                         MachineBasicBlock::iterator MBBI, const DebugLoc &DL,
+                         MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+                         unsigned Opc, unsigned NF, RISCVII::VLMUL LMul,
+                         unsigned SubRegIdx) const;
   void copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                    const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg,
                    bool KillSrc) const override;

``````````

</details>


https://github.com/llvm/llvm-project/pull/70497


More information about the llvm-commits mailing list