[llvm] 39f6d01 - [RISCV] Eliminate getVLENFactoredAmount and expose muladd [nfc] (#87881)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 8 10:24:31 PDT 2024
Author: Philip Reames
Date: 2024-04-08T10:24:27-07:00
New Revision: 39f6d015ddd69717ff1f9b817bce84d621d37731
URL: https://github.com/llvm/llvm-project/commit/39f6d015ddd69717ff1f9b817bce84d621d37731
DIFF: https://github.com/llvm/llvm-project/commit/39f6d015ddd69717ff1f9b817bce84d621d37731.diff
LOG: [RISCV] Eliminate getVLENFactoredAmount and expose muladd [nfc] (#87881)
This restructures the code to make it more obvious that most of
getVLENFactoredAmount is just a generic multiply by an immediate, and
prepares for a couple of upcoming enhancements to this code.
Note that I plan to switch mulImm to early return, but decided I'd do
that as a separate commit to keep this diff readable.
---------
Co-authored-by: Luke Lau <luke_lau at icloud.com>
Added:
Modified:
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
llvm/lib/Target/RISCV/RISCVInstrInfo.h
llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 153f936326a78d..be63bc936ae8a1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2998,24 +2998,13 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI,
#undef CASE_WIDEOP_OPCODE_LMULS
#undef CASE_WIDEOP_OPCODE_COMMON
-void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
- MachineBasicBlock &MBB,
- MachineBasicBlock::iterator II,
- const DebugLoc &DL, Register DestReg,
- int64_t Amount,
- MachineInstr::MIFlag Flag) const {
- assert(Amount > 0 && "There is no need to get VLEN scaled value.");
- assert(Amount % 8 == 0 &&
- "Reserve the stack by the multiple of one vector size.");
-
+void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator II, const DebugLoc &DL,
+ Register DestReg, uint32_t Amount,
+ MachineInstr::MIFlag Flag) const {
MachineRegisterInfo &MRI = MF.getRegInfo();
- assert(isInt<32>(Amount / 8) &&
- "Expect the number of vector registers within 32-bits.");
- uint32_t NumOfVReg = Amount / 8;
-
- BuildMI(MBB, II, DL, get(RISCV::PseudoReadVLENB), DestReg).setMIFlag(Flag);
- if (llvm::has_single_bit<uint32_t>(NumOfVReg)) {
- uint32_t ShiftAmount = Log2_32(NumOfVReg);
+ if (llvm::has_single_bit<uint32_t>(Amount)) {
+ uint32_t ShiftAmount = Log2_32(Amount);
if (ShiftAmount == 0)
return;
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
@@ -3023,23 +3012,23 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
.addImm(ShiftAmount)
.setMIFlag(Flag);
} else if (STI.hasStdExtZba() &&
- ((NumOfVReg % 3 == 0 && isPowerOf2_64(NumOfVReg / 3)) ||
- (NumOfVReg % 5 == 0 && isPowerOf2_64(NumOfVReg / 5)) ||
- (NumOfVReg % 9 == 0 && isPowerOf2_64(NumOfVReg / 9)))) {
+ ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
+ (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
+ (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
unsigned Opc;
uint32_t ShiftAmount;
- if (NumOfVReg % 9 == 0) {
+ if (Amount % 9 == 0) {
Opc = RISCV::SH3ADD;
- ShiftAmount = Log2_64(NumOfVReg / 9);
- } else if (NumOfVReg % 5 == 0) {
+ ShiftAmount = Log2_64(Amount / 9);
+ } else if (Amount % 5 == 0) {
Opc = RISCV::SH2ADD;
- ShiftAmount = Log2_64(NumOfVReg / 5);
- } else if (NumOfVReg % 3 == 0) {
+ ShiftAmount = Log2_64(Amount / 5);
+ } else if (Amount % 3 == 0) {
Opc = RISCV::SH1ADD;
- ShiftAmount = Log2_64(NumOfVReg / 3);
+ ShiftAmount = Log2_64(Amount / 3);
} else {
- llvm_unreachable("Unexpected number of vregs");
+ llvm_unreachable("implied by if-clause");
}
if (ShiftAmount)
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
@@ -3050,9 +3039,9 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
.addReg(DestReg, RegState::Kill)
.addReg(DestReg)
.setMIFlag(Flag);
- } else if (llvm::has_single_bit<uint32_t>(NumOfVReg - 1)) {
+ } else if (llvm::has_single_bit<uint32_t>(Amount - 1)) {
Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
- uint32_t ShiftAmount = Log2_32(NumOfVReg - 1);
+ uint32_t ShiftAmount = Log2_32(Amount - 1);
BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
.addReg(DestReg)
.addImm(ShiftAmount)
@@ -3061,9 +3050,9 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
.addReg(ScaledRegister, RegState::Kill)
.addReg(DestReg, RegState::Kill)
.setMIFlag(Flag);
- } else if (llvm::has_single_bit<uint32_t>(NumOfVReg + 1)) {
+ } else if (llvm::has_single_bit<uint32_t>(Amount + 1)) {
Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
- uint32_t ShiftAmount = Log2_32(NumOfVReg + 1);
+ uint32_t ShiftAmount = Log2_32(Amount + 1);
BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
.addReg(DestReg)
.addImm(ShiftAmount)
@@ -3074,7 +3063,7 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
.setMIFlag(Flag);
} else if (STI.hasStdExtM() || STI.hasStdExtZmmul()) {
Register N = MRI.createVirtualRegister(&RISCV::GPRRegClass);
- movImm(MBB, II, DL, N, NumOfVReg, Flag);
+ movImm(MBB, II, DL, N, Amount, Flag);
BuildMI(MBB, II, DL, get(RISCV::MUL), DestReg)
.addReg(DestReg, RegState::Kill)
.addReg(N, RegState::Kill)
@@ -3082,14 +3071,14 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
} else {
Register Acc;
uint32_t PrevShiftAmount = 0;
- for (uint32_t ShiftAmount = 0; NumOfVReg >> ShiftAmount; ShiftAmount++) {
- if (NumOfVReg & (1U << ShiftAmount)) {
+ for (uint32_t ShiftAmount = 0; Amount >> ShiftAmount; ShiftAmount++) {
+ if (Amount & (1U << ShiftAmount)) {
if (ShiftAmount)
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
.addReg(DestReg, RegState::Kill)
.addImm(ShiftAmount - PrevShiftAmount)
.setMIFlag(Flag);
- if (NumOfVReg >> (ShiftAmount + 1)) {
+ if (Amount >> (ShiftAmount + 1)) {
// If we don't have an accmulator yet, create it and copy DestReg.
if (!Acc) {
Acc = MRI.createVirtualRegister(&RISCV::GPRRegClass);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 3470012d1518ea..81d9c9db783c02 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -229,10 +229,12 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
unsigned OpIdx,
const TargetRegisterInfo *TRI) const override;
- void getVLENFactoredAmount(
- MachineFunction &MF, MachineBasicBlock &MBB,
- MachineBasicBlock::iterator II, const DebugLoc &DL, Register DestReg,
- int64_t Amount, MachineInstr::MIFlag Flag = MachineInstr::NoFlags) const;
+ /// Generate code to multiply the value in DestReg by Amt - handles all
+ /// the common optimizations for this idiom, and supports fallback for
+ /// subtargets which don't support multiply instructions.
+ void mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator II, const DebugLoc &DL,
+ Register DestReg, uint32_t Amt, MachineInstr::MIFlag Flag) const;
bool useMachineCombiner() const override { return true; }
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
index 11c3f2d57eb00f..713260b090e9cf 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -195,7 +195,16 @@ void RISCVRegisterInfo::adjustReg(MachineBasicBlock &MBB,
Register ScratchReg = DestReg;
if (DestReg == SrcReg)
ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
- TII->getVLENFactoredAmount(MF, MBB, II, DL, ScratchReg, ScalableValue, Flag);
+
+ assert(ScalableValue > 0 && "There is no need to get VLEN scaled value.");
+ assert(ScalableValue % 8 == 0 &&
+ "Reserve the stack by the multiple of one vector size.");
+ assert(isInt<32>(ScalableValue / 8) &&
+ "Expect the number of vector registers within 32-bits.");
+ uint32_t NumOfVReg = ScalableValue / 8;
+ BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), ScratchReg)
+ .setMIFlag(Flag);
+ TII->mulImm(MF, MBB, II, DL, ScratchReg, NumOfVReg, Flag);
BuildMI(MBB, II, DL, TII->get(ScalableAdjOpc), DestReg)
.addReg(SrcReg).addReg(ScratchReg, RegState::Kill)
.setMIFlag(Flag);
More information about the llvm-commits
mailing list