[llvm] [RISCV][GISel] Sink getOperandsMapping call out of the switch in getInstrMapping. (PR #72326)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 15:54:25 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

Use a SmallVector of `ValueMapping *` that we populate in the switch for each register operand. Each entry in the SmallVector defaults to nullptr, so we don't need to write explicit `nullptr` in the cases.

After this, we can generically fill in GPR for register operands as a default case and remove some opcodes from the switch.

---
Full diff: https://github.com/llvm/llvm-project/pull/72326.diff


1 File Affected:

- (modified) llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp (+55-54) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index cb1da8ff11c08cb..be9c735c875466b 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -181,12 +181,8 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   const ValueMapping *GPRValueMapping =
       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
                                           : RISCV::GPRB32Idx];
-  const ValueMapping *OperandsMapping = GPRValueMapping;
 
   switch (Opc) {
-  case TargetOpcode::G_INVOKE_REGION_START:
-    OperandsMapping = getOperandsMapping({});
-    break;
   case TargetOpcode::G_ADD:
   case TargetOpcode::G_SUB:
   case TargetOpcode::G_SHL:
@@ -215,14 +211,37 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_ZEXT:
   case TargetOpcode::G_SEXTLOAD:
   case TargetOpcode::G_ZEXTLOAD:
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
+                                 NumOperands);
+  case TargetOpcode::G_FADD:
+  case TargetOpcode::G_FSUB:
+  case TargetOpcode::G_FMUL:
+  case TargetOpcode::G_FDIV:
+  case TargetOpcode::G_FABS:
+  case TargetOpcode::G_FNEG:
+  case TargetOpcode::G_FSQRT:
+  case TargetOpcode::G_FMAXNUM:
+  case TargetOpcode::G_FMINNUM: {
+    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                                 getFPValueMapping(Ty.getSizeInBits()),
+                                 NumOperands);
+  }
+  }
+
+  SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
+
+  switch (Opc) {
+  case TargetOpcode::G_INVOKE_REGION_START:
     break;
   case TargetOpcode::G_LOAD: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 loads on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
@@ -237,26 +256,25 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
                  // instruction.
                  return onlyUsesFP(UseMI);
                })) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     }
 
     break;
   }
   case TargetOpcode::G_STORE: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 stores on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
     if (onlyDefinesFP(*DefMI)) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     }
     break;
   }
@@ -265,76 +283,60 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_GLOBAL_VALUE:
   case TargetOpcode::G_JUMP_TABLE:
   case TargetOpcode::G_BRCOND:
-    OperandsMapping = getOperandsMapping({GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
     break;
   case TargetOpcode::G_BR:
-    OperandsMapping = getOperandsMapping({nullptr});
     break;
   case TargetOpcode::G_BRJT:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, nullptr, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
     break;
   case TargetOpcode::G_ICMP:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
   case TargetOpcode::G_SEXT_INREG:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     break;
   case TargetOpcode::G_SELECT:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, GPRValueMapping, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
-  case TargetOpcode::G_FADD:
-  case TargetOpcode::G_FSUB:
-  case TargetOpcode::G_FMUL:
-  case TargetOpcode::G_FDIV:
-  case TargetOpcode::G_FNEG:
-  case TargetOpcode::G_FABS:
-  case TargetOpcode::G_FSQRT:
-  case TargetOpcode::G_FMAXNUM:
-  case TargetOpcode::G_FMINNUM: {
-    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getFPValueMapping(Ty.getSizeInBits());
-    break;
-  }
   case TargetOpcode::G_FMA: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    const RegisterBankInfo::ValueMapping *FPValueMapping =
-        getFPValueMapping(Ty.getSizeInBits());
-    OperandsMapping = getOperandsMapping(
-        {FPValueMapping, FPValueMapping, FPValueMapping, FPValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = OpdsMapping[2] = OpdsMapping[3] = OpdsMapping[0];
     break;
   }
   case TargetOpcode::G_FPEXT:
   case TargetOpcode::G_FPTRUNC: {
     LLT ToTy = MRI.getType(MI.getOperand(0).getReg());
     LLT FromTy = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(ToTy.getSizeInBits()),
-                            getFPValueMapping(FromTy.getSizeInBits())});
+    OpdsMapping[0] = getFPValueMapping(ToTy.getSizeInBits());
+    OpdsMapping[1] = getFPValueMapping(FromTy.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FPTOSI:
   case TargetOpcode::G_FPTOUI: {
     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping,
-                            getFPValueMapping(Ty.getSizeInBits())});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_SITOFP:
   case TargetOpcode::G_UITOFP: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getOperandsMapping(
-        {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = GPRValueMapping;
     break;
   }
   case TargetOpcode::G_FCONSTANT: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(Ty.getSizeInBits()), nullptr});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FCMP: {
@@ -343,15 +345,14 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     unsigned Size = Ty.getSizeInBits();
     assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");
 
-    auto *FPRValueMapping = getFPValueMapping(Size);
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, FPRValueMapping, FPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
     break;
   }
   default:
     return getInvalidInstructionMapping();
   }
 
-  return getInstructionMapping(DefaultMappingID, /*Cost=*/1, OperandsMapping,
-                               NumOperands);
+  return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                               getOperandsMapping(OpdsMapping), NumOperands);
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/72326


More information about the llvm-commits mailing list