[llvm] [RISCV][GISel] Sink getOperandsMapping call out of the switch in getInstrMapping. (PR #72326)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 15 11:21:38 PST 2023


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/72326

>From 55a05d8b51350e9ab305d687407bb8c7cf0af8e5 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 14 Nov 2023 15:31:07 -0800
Subject: [PATCH] [RISCV][GISel] Sink getOperandsMapping call out of the switch
 in getInstrMapping.

Use a SmallVector of ValueMapping * that we populate in the switch for
each register operand. The entry in the SmallVector defaults to nullptr
for each operand, so we don't need to set nullptr explicitly in the switch.

After this we can generically fill in GPR for registers as a default case and
remove some opcodes from the switch.
---
 .../RISCV/GISel/RISCVRegisterBankInfo.cpp     | 112 +++++++++---------
 1 file changed, 55 insertions(+), 57 deletions(-)

diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index 0777b9962468935..92aeb5c06332ab4 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -199,12 +199,8 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   const ValueMapping *GPRValueMapping =
       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
                                           : RISCV::GPRB32Idx];
-  const ValueMapping *OperandsMapping = GPRValueMapping;
 
   switch (Opc) {
-  case TargetOpcode::G_INVOKE_REGION_START:
-    OperandsMapping = getOperandsMapping({});
-    break;
   case TargetOpcode::G_ADD:
   case TargetOpcode::G_SUB:
   case TargetOpcode::G_SHL:
@@ -233,14 +229,37 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_ZEXT:
   case TargetOpcode::G_SEXTLOAD:
   case TargetOpcode::G_ZEXTLOAD:
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
+                                 NumOperands);
+  case TargetOpcode::G_FADD:
+  case TargetOpcode::G_FSUB:
+  case TargetOpcode::G_FMUL:
+  case TargetOpcode::G_FDIV:
+  case TargetOpcode::G_FABS:
+  case TargetOpcode::G_FNEG:
+  case TargetOpcode::G_FSQRT:
+  case TargetOpcode::G_FMAXNUM:
+  case TargetOpcode::G_FMINNUM: {
+    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                                 getFPValueMapping(Ty.getSizeInBits()),
+                                 NumOperands);
+  }
+  }
+
+  SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
+
+  switch (Opc) {
+  case TargetOpcode::G_INVOKE_REGION_START:
     break;
   case TargetOpcode::G_LOAD: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 loads on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
@@ -254,28 +273,25 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
                  // not, we would have had a bitcast before reaching that
                  // instruction.
                  return onlyUsesFP(UseMI, MRI, TRI);
-               })) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
-    }
+               }))
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
 
     break;
   }
   case TargetOpcode::G_STORE: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     // Use FPR64 for s64 stores on rv32.
     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
-      OperandsMapping =
-          getOperandsMapping({getFPValueMapping(64), GPRValueMapping});
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
       break;
     }
 
     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
     if (onlyDefinesFP(*DefMI, MRI, TRI)) {
-      OperandsMapping = getOperandsMapping(
-          {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
-    }
+      OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_CONSTANT:
@@ -283,76 +299,59 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_GLOBAL_VALUE:
   case TargetOpcode::G_JUMP_TABLE:
   case TargetOpcode::G_BRCOND:
-    OperandsMapping = getOperandsMapping({GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
     break;
   case TargetOpcode::G_BR:
-    OperandsMapping = getOperandsMapping({nullptr});
     break;
   case TargetOpcode::G_BRJT:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, nullptr, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
     break;
   case TargetOpcode::G_ICMP:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, GPRValueMapping, GPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
   case TargetOpcode::G_SEXT_INREG:
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping, GPRValueMapping, nullptr});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
     break;
   case TargetOpcode::G_SELECT:
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, GPRValueMapping, GPRValueMapping, GPRValueMapping});
-    break;
-  case TargetOpcode::G_FADD:
-  case TargetOpcode::G_FSUB:
-  case TargetOpcode::G_FMUL:
-  case TargetOpcode::G_FDIV:
-  case TargetOpcode::G_FNEG:
-  case TargetOpcode::G_FABS:
-  case TargetOpcode::G_FSQRT:
-  case TargetOpcode::G_FMAXNUM:
-  case TargetOpcode::G_FMINNUM: {
-    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = GPRValueMapping;
+    OpdsMapping[2] = GPRValueMapping;
+    OpdsMapping[3] = GPRValueMapping;
     break;
-  }
   case TargetOpcode::G_FMA: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    const RegisterBankInfo::ValueMapping *FPValueMapping =
-        getFPValueMapping(Ty.getSizeInBits());
-    OperandsMapping = getOperandsMapping(
-        {FPValueMapping, FPValueMapping, FPValueMapping, FPValueMapping});
+    std::fill_n(OpdsMapping.begin(), 4, getFPValueMapping(Ty.getSizeInBits()));
     break;
   }
   case TargetOpcode::G_FPEXT:
   case TargetOpcode::G_FPTRUNC: {
     LLT ToTy = MRI.getType(MI.getOperand(0).getReg());
     LLT FromTy = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(ToTy.getSizeInBits()),
-                            getFPValueMapping(FromTy.getSizeInBits())});
+    OpdsMapping[0] = getFPValueMapping(ToTy.getSizeInBits());
+    OpdsMapping[1] = getFPValueMapping(FromTy.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FPTOSI:
   case TargetOpcode::G_FPTOUI: {
     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
-    OperandsMapping =
-        getOperandsMapping({GPRValueMapping,
-                            getFPValueMapping(Ty.getSizeInBits())});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_SITOFP:
   case TargetOpcode::G_UITOFP: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping = getOperandsMapping(
-        {getFPValueMapping(Ty.getSizeInBits()), GPRValueMapping});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
+    OpdsMapping[1] = GPRValueMapping;
     break;
   }
   case TargetOpcode::G_FCONSTANT: {
     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-    OperandsMapping =
-        getOperandsMapping({getFPValueMapping(Ty.getSizeInBits()), nullptr});
+    OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
     break;
   }
   case TargetOpcode::G_FCMP: {
@@ -361,15 +360,14 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     unsigned Size = Ty.getSizeInBits();
     assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");
 
-    auto *FPRValueMapping = getFPValueMapping(Size);
-    OperandsMapping = getOperandsMapping(
-        {GPRValueMapping, nullptr, FPRValueMapping, FPRValueMapping});
+    OpdsMapping[0] = GPRValueMapping;
+    OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
     break;
   }
   default:
     return getInvalidInstructionMapping();
   }
 
-  return getInstructionMapping(DefaultMappingID, /*Cost=*/1, OperandsMapping,
-                               NumOperands);
+  return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
+                               getOperandsMapping(OpdsMapping), NumOperands);
 }



More information about the llvm-commits mailing list