[llvm] [RISCV][GISel] Sink getOperandsMapping call out of the switch in getInstrMapping. (PR #72326)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 15:53:58 PST 2023


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/72326

Use a SmallVector of `ValueMapping *` that we populate in the switch for each register operand. The entry in the SmallVector defaults to nullptr for each operand so we don't need to write explicit `nullptr` in the cases.

After this we can generically fill in GPR for registers as a default case and remove some opcodes from the switch.

>From 182894a141044e2014dc83785e57b3348bc812f1 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 14 Nov 2023 15:31:07 -0800
Subject: [PATCH] [RISCV][GISel] Sink getOperandsMapping call out of the switch
 in getInstrMapping.

Use a SmallVector of ValueMapping * that we populate in the switch for
each register operand. The entry in the SmallVector defaults to nullptr
for each operand so we don't need to write an explicit nullptr in the switch.

After this we can generically fill in GPR for registers as a default case and
remove some opcodes from the switch.
---
 .../RISCV/GISel/RISCVRegisterBankInfo.cpp     | 109 +++++++++---------
 1 file changed, 55 insertions(+), 54 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index cb1da8ff11c08cb..be9c735c875466b 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -181,12 +181,8 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   const ValueMapping *GPRValueMapping =
       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
                                           : RISCV::GPRB32Idx];
-  const ValueMapping *OperandsMapping = GPRValueMapping;
 
   switch (Opc) {
-  case TargetOpcode::G_INVOKE_REGION_START:
-    OperandsMapping = getOperandsMapping({});
-    break;
   case TargetOpcode::G_ADD:
   case TargetOpcode::G_SUB:
   case TargetOpcode::G_SHL:
@@ -215,14 +211,37 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_ZEXT:
   case TargetOpcode::G_SEXTLOAD:
   case TargetOpcode::G_ZEXTLOAD:
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
+                                 NumOperands);
+  case TargetOpcode::G_FADD:
+  case TargetOpcode::G_FSUB:
+  case TargetOpcode::G_FMUL:
+  case TargetOpcode::G_FDIV:
+  case TargetOpcode::G_FABS:
+  case TargetOpcode::G_FNEG:
+  case TargetOpcode::G_FSQRT:
+  case TargetOpcode::G_FMAXNUM:
+  case TargetOpcode::G_FMINNUM: {
+    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                                 getFPValueMapping(Ty.getSizeInBits()),
+                                 NumOperands);
+  }
+  }
+
+  SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
+
+  switch (Opc) {
+  case TargetOpcode::G_INVOKE_REGION_START:
     break;
   case TargetOpcode::G_LOAD: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 loads on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
@@ -237,26 +256,25 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
                  // instruction.
                  return onlyUsesFP(UseMI);
                })) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     }
 
     break;
   }
   case TargetOpcode::G_STORE: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 stores on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
     if (onlyDefinesFP(*DefMI)) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     }
     break;
   }
@@ -265,76 +283,60 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_GLOBAL_VALUE:
   case TargetOpcode::G_JUMP_TABLE:
   case TargetOpcode::G_BRCOND:
-    OperandsMapping = getOperandsMapping({GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
     break;
   case TargetOpcode::G_BR:
-    OperandsMapping = getOperandsMapping({nullptr});
     break;
   case TargetOpcode::G_BRJT:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, nullptr, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
     break;
   case TargetOpcode::G_ICMP:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
   case TargetOpcode::G_SEXT_INREG:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     break;
   case TargetOpcode::G_SELECT:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, GPRValueMapping, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
-  case TargetOpcode::G_FADD:
-  case TargetOpcode::G_FSUB:
-  case TargetOpcode::G_FMUL:
-  case TargetOpcode::G_FDIV:
-  case TargetOpcode::G_FNEG:
-  case TargetOpcode::G_FABS:
-  case TargetOpcode::G_FSQRT:
-  case TargetOpcode::G_FMAXNUM:
-  case TargetOpcode::G_FMINNUM: {
-    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getFPValueMapping(Ty.getSizeInBits());
-    break;
-  }
   case TargetOpcode::G_FMA: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    const RegisterBankInfo::ValueMapping *FPValueMapping =
-        getFPValueMapping(Ty.getSizeInBits());
-    OperandsMapping = getOperandsMapping(
-        {FPValueMapping, FPValueMapping, FPValueMapping, FPValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = OpdsMapping[2] = OpdsMapping[3] = OpdsMapping[0];
     break;
   }
   case TargetOpcode::G_FPEXT:
   case TargetOpcode::G_FPTRUNC: {
     LLT ToTy = MRI.getType(MI.getOperand(0).getReg());
     LLT FromTy = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(ToTy.getSizeInBits()),
-                            getFPValueMapping(FromTy.getSizeInBits())});
+    OpdsMapping[0] = getFPValueMapping(ToTy.getSizeInBits());
+    OpdsMapping[1] = getFPValueMapping(FromTy.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FPTOSI:
   case TargetOpcode::G_FPTOUI: {
     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping,
-                            getFPValueMapping(Ty.getSizeInBits())});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_SITOFP:
   case TargetOpcode::G_UITOFP: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getOperandsMapping(
-        {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = GPRValueMapping;
     break;
   }
   case TargetOpcode::G_FCONSTANT: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(Ty.getSizeInBits()), nullptr});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FCMP: {
@@ -343,15 +345,14 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     unsigned Size = Ty.getSizeInBits();
     assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");
 
-    auto *FPRValueMapping = getFPValueMapping(Size);
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, FPRValueMapping, FPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
     break;
   }
   default:
     return getInvalidInstructionMapping();
   }
 
-  return getInstructionMapping(DefaultMappingID, /*Cost=*/1, OperandsMapping,
-                               NumOperands);
+  return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                               getOperandsMapping(OpdsMapping), NumOperands);
 }



More information about the llvm-commits mailing list