[llvm] [RISCV] Introduce VLOptimizer pass (PR #108640)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 11 06:15:48 PDT 2024
================
@@ -0,0 +1,833 @@
+//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This pass reduces the VL where possible at the MI level, before VSETVLI
+// instructions are inserted.
+//
+// The purpose of this optimization is to make the VL argument of every
+// instruction that has one as small as possible. This is implemented by
+// visiting each instruction in reverse order and, if it has a VL argument,
+// checking whether that VL can be reduced.
+//
+//===---------------------------------------------------------------------===//
+
+#include "RISCV.h"
+#include "RISCVMachineFunctionInfo.h"
+#include "RISCVSubtarget.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-vl-optimizer"
+#define PASS_NAME "RISC-V VL Optimizer"
+
+namespace {
+
+/// Machine-function pass that reduces the VL operand of RVV instructions
+/// where possible, before VSETVLI instructions are inserted.
+class RISCVVLOptimizer : public MachineFunctionPass {
+  // Per-function analysis state. Default to null rather than leaving the
+  // pointers indeterminate; presumably assigned in runOnMachineFunction
+  // (not visible here) -- TODO confirm.
+  const MachineRegisterInfo *MRI = nullptr;
+  const MachineDominatorTree *MDT = nullptr;
+
+public:
+  static char ID;
+
+  RISCVVLOptimizer() : MachineFunctionPass(ID) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // Only VL operands are adjusted, so the CFG is declared preserved.
+    AU.setPreservesCFG();
+    AU.addRequired<MachineDominatorTreeWrapperPass>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override { return PASS_NAME; }
+
+private:
+  // NOTE(review): CommonVL appears to be an in/out accumulator of the VL
+  // shared by MI's users -- confirm against the definition.
+  bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+  bool tryReduceVL(MachineInstr &MI);
+  bool isCandidate(const MachineInstr &MI) const;
+};
+
+} // end anonymous namespace
+
+// Storage for the pass ID handed to MachineFunctionPass in the constructor.
+char RISCVVLOptimizer::ID = 0;
+// Register the pass under DEBUG_TYPE/PASS_NAME and record its dependency on
+// the machine dominator tree analysis (also required in getAnalysisUsage).
+INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
+
+/// Create a fresh instance of the VL optimizer; the caller takes ownership.
+FunctionPass *llvm::createRISCVVLOptimizerPass() {
+  auto *NewPass = new RISCVVLOptimizer();
+  return NewPass;
+}
+
+/// Return true if R is a physical or virtual vector register, false otherwise.
+/// For a physical register only membership in RISCV::VRRegClass is checked;
+/// a virtual register qualifies if its class's TSFlags mark it as an RVV
+/// register class. NoRegister (e.g. an undefined passthru) returns false.
+static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
+  if (R.isPhysical())
+    return RISCV::VRRegClass.contains(R);
+  // getRegClass is only valid for virtual registers; this also rejects
+  // NoRegister instead of handing it to MRI.
+  if (!R.isVirtual())
+    return false;
+  const TargetRegisterClass *RC = MRI->getRegClass(R);
+  return RISCVRI::isVRegClass(RC->TSFlags);
+}
+
+/// Represents the EMUL and EEW of a MachineOperand.
+struct OperandInfo {
+  enum class State {
+    Unknown,
+    Known,
+  } S;
+
+  // Represent as 1,2,4,8, ... and fractional indicator. This is because
+  // EMUL can take on values that don't map to RISCVII::VLMUL values exactly.
+  // For example, a mask operand can have an EMUL less than MF8.
+  std::optional<std::pair<unsigned, bool>> EMUL;
+
+  // Log2 of the effective element width. Only meaningful when S is Known;
+  // zero-initialized so the Unknown state holds no indeterminate value.
+  unsigned Log2EEW = 0;
+
+  OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW)
+      : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {
+  }
+
+  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
+      : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {}
+
+  OperandInfo() : S(State::Unknown) {}
+
+  bool isUnknown() const { return S == State::Unknown; }
+  bool isKnown() const { return S == State::Known; }
+
+  /// Return true if A and B have the same EEW and the same EMUL (both the
+  /// magnitude and the fractional flag). Both must be Known.
+  static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
+    assert(A.isKnown() && B.isKnown() && "Both operands must be known");
+    return A.Log2EEW == B.Log2EEW && A.EMUL == B.EMUL;
+  }
+
+  /// Print as "Unknown" or e.g. "EMUL: m2, EEW: 32" / "EMUL: mf4, EEW: 8".
+  void print(raw_ostream &OS) const {
+    if (isUnknown()) {
+      OS << "Unknown";
+      return;
+    }
+    assert(EMUL && "Expected EMUL to have value");
+    // Integral EMULs print as mN and fractional ones as mfN; previously the
+    // "m"/"f" prefixes were emitted in the wrong branches.
+    OS << "EMUL: m";
+    if (EMUL->second)
+      OS << "f";
+    OS << EMUL->first;
+    OS << ", EEW: " << (1 << Log2EEW);
+  }
+};
+
+/// Stream an OperandInfo using OperandInfo::print.
+static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
+  OI.print(OS);
+  return OS;
+}
+
+namespace llvm {
+namespace RISCVVType {
+/// Return the RISCVII::VLMUL that is two times VLMul.
+/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8.
+static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) {
+  switch (VLMul) {
+  case RISCVII::VLMUL::LMUL_F8:
+    return RISCVII::VLMUL::LMUL_F4;
+  case RISCVII::VLMUL::LMUL_F4:
+    return RISCVII::VLMUL::LMUL_F2;
+  case RISCVII::VLMUL::LMUL_F2:
+    return RISCVII::VLMUL::LMUL_1;
+  case RISCVII::VLMUL::LMUL_1:
+    return RISCVII::VLMUL::LMUL_2;
+  case RISCVII::VLMUL::LMUL_2:
+    return RISCVII::VLMUL::LMUL_4;
+  case RISCVII::VLMUL::LMUL_4:
+    return RISCVII::VLMUL::LMUL_8;
+  case RISCVII::VLMUL::LMUL_8:
+  default:
+    llvm_unreachable("Could not multiply VLMul by 2");
+  }
+}
+
+/// Return EMUL = (EEW / SEW) * LMUL, where EEW comes from Log2EEW, LMUL comes
+/// from the TSFlags of MI and SEW comes from MI's SEW operand.
+/// The result is {magnitude, isFractional}, e.g. {4, true} means 1/4.
+static std::pair<unsigned, bool>
+getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
+  RISCVII::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
+  auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
+  unsigned MILog2SEW =
+      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+
+  // Build (EEW/SEW)*LMUL as the fraction Num/Denom, folding a fractional
+  // LMUL (1/MILMUL) into the denominator, then reduce it by the GCD. Keep
+  // everything unsigned to match std::gcd's result type.
+  unsigned Num = 1U << Log2EEW;
+  unsigned Denom = 1U << MILog2SEW;
+  if (MILMULIsFractional)
+    Denom *= MILMUL;
+  else
+    Num *= MILMUL;
+  const unsigned GCD = std::gcd(Num, Denom);
+  Num /= GCD;
+  Denom /= GCD;
+  // EEW, SEW and LMUL are all powers of two, so after reduction one of
+  // Num/Denom is 1: the larger one is the magnitude.
+  return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
+}
+} // end namespace RISCVVType
+} // end namespace llvm
+
+/// OperandInfo for an operand of an integer extension (vzext/vsext.vfN).
+/// The destination (operand 0) has EEW=SEW and EMUL=LMUL. The source has
+/// EEW=SEW/Factor (e.g. vf2 => SEW/2) and EMUL=(EEW/SEW)*LMUL, where LMUL
+/// and SEW are those of MI.
+static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor,
+                                                  const MachineInstr &MI,
+                                                  const MachineOperand &MO) {
+  const MCInstrDesc &Desc = MI.getDesc();
+  const unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+
+  if (MO.getOperandNo() != 0) {
+    // Narrow source operand: its element width is SEW divided by Factor.
+    const unsigned SrcLog2EEW = Log2_32((1U << MILog2SEW) / Factor);
+    return OperandInfo(
+        RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(SrcLog2EEW, MI),
+        SrcLog2EEW);
+  }
+
+  // Destination operand uses the instruction's own LMUL and SEW.
+  return OperandInfo(RISCVII::getLMul(Desc.TSFlags), MILog2SEW);
+}
+
+/// Check whether MO is a mask operand of MI: a vector register operand whose
+/// entry in MI's instruction descriptor has register class VMV0.
+static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
+                          const MachineRegisterInfo *MRI) {
+  if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI))
+    return false;
+
+  const MCInstrDesc &Desc = MI.getDesc();
+  // Operands past the descriptor (e.g. implicit operands) have no operand
+  // info entry; indexing Desc.operands() with them would be out of bounds.
+  const unsigned OpNo = MO.getOperandNo();
+  if (OpNo >= Desc.getNumOperands())
+    return false;
+  return Desc.operands()[OpNo].RegClass == RISCV::VMV0RegClassID;
+}
+
+/// Return the OperandInfo for MO, which is an operand of MI.
+static OperandInfo getOperandInfo(const MachineInstr &MI,
+ const MachineOperand &MO,
+ const MachineRegisterInfo *MRI) {
+ const RISCVVPseudosTable::PseudoInfo *RVV =
+ RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
+ assert(RVV && "Could not find MI in PseudoTable");
+
+ // MI has a VLMUL and SEW associated with it. The RVV specification defines
+ // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and
+ // MI.SEW.
+ RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
+ unsigned MILog2SEW =
+ MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+
+ const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc());
+
+ // We bail out early for instructions that have passthru with non NoRegister,
+ // which means they are using TU policy. We are not interested in these
+ // since they must preserve the entire register content.
+ if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() &&
+ (MO.getReg() != RISCV::NoRegister))
+ return {};
+
+ bool IsMODef = MO.getOperandNo() == 0;
+
+ // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL
+ if (isMaskOperand(MI, MO, MRI))
+ return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+
+ // switch against BaseInstr to reduce number of cases that need to be
+ // considered.
+ switch (RVV->BaseInstr) {
+
+ // 6. Configuration-Setting Instructions
+ // Configuration setting instructions do not read or write vector registers
+ case RISCV::VSETIVLI:
+ case RISCV::VSETVL:
+ case RISCV::VSETVLI:
+ llvm_unreachable("Configuration setting instructions do not read or write "
+ "vector registers");
+
+ // 11. Vector Integer Arithmetic Instructions
+ // 11.1. Vector Single-Width Integer Add and Subtract
+ case RISCV::VADD_VI:
+ case RISCV::VADD_VV:
+ case RISCV::VADD_VX:
+ case RISCV::VSUB_VV:
+ case RISCV::VSUB_VX:
+ case RISCV::VRSUB_VI:
+ case RISCV::VRSUB_VX:
+ // 11.5. Vector Bitwise Logical Instructions
+ // 11.6. Vector Single-Width Shift Instructions
+ // EEW=SEW. EMUL=LMUL.
+ case RISCV::VAND_VI:
+ case RISCV::VAND_VV:
+ case RISCV::VAND_VX:
+ case RISCV::VOR_VI:
+ case RISCV::VOR_VV:
+ case RISCV::VOR_VX:
+ case RISCV::VXOR_VI:
+ case RISCV::VXOR_VV:
+ case RISCV::VXOR_VX:
+ case RISCV::VSLL_VI:
+ case RISCV::VSLL_VV:
+ case RISCV::VSLL_VX:
+ case RISCV::VSRL_VI:
+ case RISCV::VSRL_VV:
+ case RISCV::VSRL_VX:
+ case RISCV::VSRA_VI:
+ case RISCV::VSRA_VV:
+ case RISCV::VSRA_VX:
+ // 11.9. Vector Integer Min/Max Instructions
+ // EEW=SEW. EMUL=LMUL.
+ case RISCV::VMINU_VV:
+ case RISCV::VMINU_VX:
+ case RISCV::VMIN_VV:
+ case RISCV::VMIN_VX:
+ case RISCV::VMAXU_VV:
+ case RISCV::VMAXU_VX:
+ case RISCV::VMAX_VV:
+ case RISCV::VMAX_VX:
+ // 11.10. Vector Single-Width Integer Multiply Instructions
+ // Source and Dest EEW=SEW and EMUL=LMUL.
+ case RISCV::VMUL_VV:
+ case RISCV::VMUL_VX:
+ case RISCV::VMULH_VV:
+ case RISCV::VMULH_VX:
+ case RISCV::VMULHU_VV:
+ case RISCV::VMULHU_VX:
+ case RISCV::VMULHSU_VV:
+ case RISCV::VMULHSU_VX:
+ // 11.11. Vector Integer Divide Instructions
+ // EEW=SEW. EMUL=LMUL.
+ case RISCV::VDIVU_VV:
+ case RISCV::VDIVU_VX:
+ case RISCV::VDIV_VV:
+ case RISCV::VDIV_VX:
+ case RISCV::VREMU_VV:
+ case RISCV::VREMU_VX:
+ case RISCV::VREM_VV:
+ case RISCV::VREM_VX:
+ // 11.13. Vector Single-Width Integer Multiply-Add Instructions
+ // EEW=SEW. EMUL=LMUL.
+ case RISCV::VMACC_VV:
+ case RISCV::VMACC_VX:
+ case RISCV::VNMSAC_VV:
+ case RISCV::VNMSAC_VX:
+ case RISCV::VMADD_VV:
+ case RISCV::VMADD_VX:
+ case RISCV::VNMSUB_VV:
+ case RISCV::VNMSUB_VX:
+ // 11.15. Vector Integer Merge Instructions
+ // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
+ // (EEW/SEW)*LMUL. Mask operand is handled before this switch.
+ case RISCV::VMERGE_VIM:
+ case RISCV::VMERGE_VVM:
+ case RISCV::VMERGE_VXM:
+ // 11.16. Vector Integer Move Instructions
+ // 12. Vector Fixed-Point Arithmetic Instructions
+ // 12.1. Vector Single-Width Saturating Add and Subtract
+ // 12.2. Vector Single-Width Averaging Add and Subtract
+ // EEW=SEW. EMUL=LMUL.
+ case RISCV::VMV_V_I:
+ case RISCV::VMV_V_V:
+ case RISCV::VMV_V_X:
+ case RISCV::VSADDU_VI:
+ case RISCV::VSADDU_VV:
+ case RISCV::VSADDU_VX:
+ case RISCV::VSADD_VI:
+ case RISCV::VSADD_VV:
+ case RISCV::VSADD_VX:
+ case RISCV::VSSUBU_VV:
+ case RISCV::VSSUBU_VX:
+ case RISCV::VSSUB_VV:
+ case RISCV::VSSUB_VX:
+ case RISCV::VAADDU_VV:
+ case RISCV::VAADDU_VX:
+ case RISCV::VAADD_VV:
+ case RISCV::VAADD_VX:
+ case RISCV::VASUBU_VV:
+ case RISCV::VASUBU_VX:
+ case RISCV::VASUB_VV:
+ case RISCV::VASUB_VX:
+ // 12.4. Vector Single-Width Scaling Shift Instructions
+ // EEW=SEW. EMUL=LMUL.
+ case RISCV::VSSRL_VI:
+ case RISCV::VSSRL_VV:
+ case RISCV::VSSRL_VX:
+ case RISCV::VSSRA_VI:
+ case RISCV::VSSRA_VV:
+ case RISCV::VSSRA_VX:
+ // 16. Vector Permutation Instructions
+ // 16.1. Integer Scalar Move Instructions
+ // 16.2. Floating-Point Scalar Move Instructions
+ // EMUL=LMUL. EEW=SEW.
+ case RISCV::VMV_X_S:
+ case RISCV::VMV_S_X:
+ case RISCV::VFMV_F_S:
+ case RISCV::VFMV_S_F:
+ // 16.3. Vector Slide Instructions
+ // EMUL=LMUL. EEW=SEW.
+ case RISCV::VSLIDEUP_VI:
+ case RISCV::VSLIDEUP_VX:
+ case RISCV::VSLIDEDOWN_VI:
+ case RISCV::VSLIDEDOWN_VX:
+ case RISCV::VSLIDE1UP_VX:
+ case RISCV::VFSLIDE1UP_VF:
+ case RISCV::VSLIDE1DOWN_VX:
+ case RISCV::VFSLIDE1DOWN_VF:
+ // 16.4. Vector Register Gather Instructions
+ // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1.
+ case RISCV::VRGATHER_VI:
+ case RISCV::VRGATHER_VV:
+ case RISCV::VRGATHER_VX:
+ // 16.5. Vector Compress Instruction
+ // EMUL=LMUL. EEW=SEW.
+ case RISCV::VCOMPRESS_VM:
+ return OperandInfo(MIVLMul, MILog2SEW);
+
+ // 11.2. Vector Widening Integer Add/Subtract
+ // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL.
+ case RISCV::VWADDU_VV:
+ case RISCV::VWADDU_VX:
+ case RISCV::VWSUBU_VV:
+ case RISCV::VWSUBU_VX:
+ case RISCV::VWADD_VV:
+ case RISCV::VWADD_VX:
+ case RISCV::VWSUB_VV:
+ case RISCV::VWSUB_VX:
+ case RISCV::VWSLL_VI:
+ // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
+ // Destination EEW=2*SEW and EMUL=2*EMUL. Source EEW=SEW and EMUL=LMUL.
+ case RISCV::VSMUL_VV:
+ case RISCV::VSMUL_VX:
----------------
lukel97 wrote:
Are these not single-width? I thought the multiply is performed at 2*SEW internally, but the destination is still SEW.
https://github.com/llvm/llvm-project/pull/108640
More information about the llvm-commits
mailing list