[llvm] [RISCV] Eliminate getVLENFactoredAmount and expose muladd [nfc] (PR #87881)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 6 09:48:53 PDT 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/87881

This restructures the code to make it more obvious that most of getVLENFactoredAmount is just a generic multiply by an immediate (now exposed as RISCVInstrInfo::mulImm), and prepares for a couple of upcoming enhancements to this code.

Note that I plan to switch mulImm to use early returns, but decided to do that in a separate commit to keep this diff readable.

>From 551aaed2bde8fafe6f1eb698a6cf5d01628251e8 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 5 Apr 2024 14:59:42 -0700
Subject: [PATCH] [RISCV] Eliminate getVLENFactoredAmount and expose muladd
 [nfc]

This restructures the code to make it more obvious that most of
getVLENFactoredAmount is just a generic multiply by an immediate (now
exposed as RISCVInstrInfo::mulImm), and prepares for a couple of upcoming
enhancements to this code.

Note that I plan to switch mulImm to use early returns, but decided to do
that in a separate commit to keep this diff readable.
---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp    | 59 +++++++++------------
 llvm/lib/Target/RISCV/RISCVInstrInfo.h      | 10 ++--
 llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp | 10 +++-
 3 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5582de51b17d19..f612b582904a65 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -3052,24 +3052,13 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI,
 #undef CASE_WIDEOP_OPCODE_LMULS
 #undef CASE_WIDEOP_OPCODE_COMMON
 
-void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
-                                           MachineBasicBlock &MBB,
-                                           MachineBasicBlock::iterator II,
-                                           const DebugLoc &DL, Register DestReg,
-                                           int64_t Amount,
-                                           MachineInstr::MIFlag Flag) const {
-  assert(Amount > 0 && "There is no need to get VLEN scaled value.");
-  assert(Amount % 8 == 0 &&
-         "Reserve the stack by the multiple of one vector size.");
-
+void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
+                            MachineBasicBlock::iterator II, const DebugLoc &DL,
+                            Register DestReg, int32_t Amount,
+                            MachineInstr::MIFlag Flag) const {
   MachineRegisterInfo &MRI = MF.getRegInfo();
-  assert(isInt<32>(Amount / 8) &&
-         "Expect the number of vector registers within 32-bits.");
-  uint32_t NumOfVReg = Amount / 8;
-
-  BuildMI(MBB, II, DL, get(RISCV::PseudoReadVLENB), DestReg).setMIFlag(Flag);
-  if (llvm::has_single_bit<uint32_t>(NumOfVReg)) {
-    uint32_t ShiftAmount = Log2_32(NumOfVReg);
+  if (llvm::has_single_bit<uint32_t>(Amount)) {
+    uint32_t ShiftAmount = Log2_32(Amount);
     if (ShiftAmount == 0)
       return;
     BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
@@ -3077,23 +3066,23 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
         .addImm(ShiftAmount)
         .setMIFlag(Flag);
   } else if (STI.hasStdExtZba() &&
-             ((NumOfVReg % 3 == 0 && isPowerOf2_64(NumOfVReg / 3)) ||
-              (NumOfVReg % 5 == 0 && isPowerOf2_64(NumOfVReg / 5)) ||
-              (NumOfVReg % 9 == 0 && isPowerOf2_64(NumOfVReg / 9)))) {
+             ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
+              (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
+              (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
     // We can use Zba SHXADD+SLLI instructions for multiply in some cases.
     unsigned Opc;
     uint32_t ShiftAmount;
-    if (NumOfVReg % 9 == 0) {
+    if (Amount % 9 == 0) {
       Opc = RISCV::SH3ADD;
-      ShiftAmount = Log2_64(NumOfVReg / 9);
-    } else if (NumOfVReg % 5 == 0) {
+      ShiftAmount = Log2_64(Amount / 9);
+    } else if (Amount % 5 == 0) {
       Opc = RISCV::SH2ADD;
-      ShiftAmount = Log2_64(NumOfVReg / 5);
-    } else if (NumOfVReg % 3 == 0) {
+      ShiftAmount = Log2_64(Amount / 5);
+    } else if (Amount % 3 == 0) {
       Opc = RISCV::SH1ADD;
-      ShiftAmount = Log2_64(NumOfVReg / 3);
+      ShiftAmount = Log2_64(Amount / 3);
     } else {
-      llvm_unreachable("Unexpected number of vregs");
+      llvm_unreachable("impied by if-clause");
     }
     if (ShiftAmount)
       BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
@@ -3104,9 +3093,9 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
         .addReg(DestReg, RegState::Kill)
         .addReg(DestReg)
         .setMIFlag(Flag);
-  } else if (llvm::has_single_bit<uint32_t>(NumOfVReg - 1)) {
+  } else if (llvm::has_single_bit<uint32_t>(Amount - 1)) {
     Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-    uint32_t ShiftAmount = Log2_32(NumOfVReg - 1);
+    uint32_t ShiftAmount = Log2_32(Amount - 1);
     BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
         .addReg(DestReg)
         .addImm(ShiftAmount)
@@ -3115,9 +3104,9 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
         .addReg(ScaledRegister, RegState::Kill)
         .addReg(DestReg, RegState::Kill)
         .setMIFlag(Flag);
-  } else if (llvm::has_single_bit<uint32_t>(NumOfVReg + 1)) {
+  } else if (llvm::has_single_bit<uint32_t>(Amount + 1)) {
     Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-    uint32_t ShiftAmount = Log2_32(NumOfVReg + 1);
+    uint32_t ShiftAmount = Log2_32(Amount + 1);
     BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
         .addReg(DestReg)
         .addImm(ShiftAmount)
@@ -3128,7 +3117,7 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
         .setMIFlag(Flag);
   } else if (STI.hasStdExtM() || STI.hasStdExtZmmul()) {
     Register N = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-    movImm(MBB, II, DL, N, NumOfVReg, Flag);
+    movImm(MBB, II, DL, N, Amount, Flag);
     BuildMI(MBB, II, DL, get(RISCV::MUL), DestReg)
         .addReg(DestReg, RegState::Kill)
         .addReg(N, RegState::Kill)
@@ -3136,14 +3125,14 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
   } else {
     Register Acc;
     uint32_t PrevShiftAmount = 0;
-    for (uint32_t ShiftAmount = 0; NumOfVReg >> ShiftAmount; ShiftAmount++) {
-      if (NumOfVReg & (1U << ShiftAmount)) {
+    for (uint32_t ShiftAmount = 0; Amount >> ShiftAmount; ShiftAmount++) {
+      if (Amount & (1U << ShiftAmount)) {
         if (ShiftAmount)
           BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
               .addReg(DestReg, RegState::Kill)
               .addImm(ShiftAmount - PrevShiftAmount)
               .setMIFlag(Flag);
-        if (NumOfVReg >> (ShiftAmount + 1)) {
+        if (Amount >> (ShiftAmount + 1)) {
           // If we don't have an accmulator yet, create it and copy DestReg.
           if (!Acc) {
             Acc = MRI.createVirtualRegister(&RISCV::GPRRegClass);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index dd049fca059719..0a74b92303da7f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -229,10 +229,12 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
                           unsigned OpIdx,
                           const TargetRegisterInfo *TRI) const override;
 
-  void getVLENFactoredAmount(
-      MachineFunction &MF, MachineBasicBlock &MBB,
-      MachineBasicBlock::iterator II, const DebugLoc &DL, Register DestReg,
-      int64_t Amount, MachineInstr::MIFlag Flag = MachineInstr::NoFlags) const;
+  // Generate code to multiply the value in DestReg by Amount - handles
+  // all the common optimizations for this idiom, and supports fallback for
+  // subtargets which don't support multiply instructions.
+  void mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
+              MachineBasicBlock::iterator II, const DebugLoc &DL,
+              Register DestReg, int32_t Amt, MachineInstr::MIFlag Flag) const;
 
   bool useMachineCombiner() const override { return true; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
index 11c3f2d57eb00f..8f0fdb27d2926f 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -195,7 +195,15 @@ void RISCVRegisterInfo::adjustReg(MachineBasicBlock &MBB,
     Register ScratchReg = DestReg;
     if (DestReg == SrcReg)
       ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-    TII->getVLENFactoredAmount(MF, MBB, II, DL, ScratchReg, ScalableValue, Flag);
+
+    assert(ScalableValue > 0 && "There is no need to get VLEN scaled value.");
+    assert(ScalableValue % 8 == 0 &&
+           "Reserve the stack by the multiple of one vector size.");
+    assert(isInt<32>(ScalableValue / 8) &&
+           "Expect the number of vector registers within 32-bits.");
+    uint32_t NumOfVReg = ScalableValue / 8;
+    BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), ScratchReg).setMIFlag(Flag);
+    TII->mulImm(MF, MBB, II, DL, ScratchReg, NumOfVReg, Flag);
     BuildMI(MBB, II, DL, TII->get(ScalableAdjOpc), DestReg)
       .addReg(SrcReg).addReg(ScratchReg, RegState::Kill)
       .setMIFlag(Flag);



More information about the llvm-commits mailing list