[llvm] [RISCV] Introduce VLOptimizer pass (PR #108640)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 06:29:31 PDT 2024


================
@@ -0,0 +1,833 @@
+//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This pass reduces the VL where possible at the MI level, before VSETVLI
+// instructions are inserted.
+//
+// The purpose of this optimization is to make the VL argument, for instructions
+// that have a VL argument, as small as possible. This is implemented by
+// visiting each instruction in reverse order and, if it has a VL argument,
+// checking whether that VL can be reduced.
+//
+//===---------------------------------------------------------------------===//
+
+#include "RISCV.h"
+#include "RISCVMachineFunctionInfo.h"
+#include "RISCVSubtarget.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-vl-optimizer"
+#define PASS_NAME "RISC-V VL Optimizer"
+
+namespace {
+
+/// Machine function pass that tries to reduce the VL operand of instructions
+/// that have one, before VSETVLI instructions are inserted.
+class RISCVVLOptimizer : public MachineFunctionPass {
+  // Cached per-function state. Null-initialised so that a use before
+  // assignment is a deterministic null dereference rather than a read through
+  // an indeterminate pointer.
+  // NOTE(review): presumably assigned at the start of runOnMachineFunction;
+  // its definition is not visible in this hunk -- confirm.
+  const MachineRegisterInfo *MRI = nullptr;
+  const MachineDominatorTree *MDT = nullptr;
+
+public:
+  /// Pass identification member, handed to the MachineFunctionPass base.
+  static char ID;
+
+  RISCVVLOptimizer() : MachineFunctionPass(ID) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  /// Declares that the pass preserves the CFG and requires the machine
+  /// dominator tree analysis.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<MachineDominatorTreeWrapperPass>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override { return PASS_NAME; }
+
+private:
+  // NOTE(review): the definitions of the helpers below are outside this hunk;
+  // judging by their names they inspect the users of MI, attempt the VL
+  // reduction on MI, and filter which instructions are considered -- confirm
+  // against the definitions.
+  bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI);
+  bool tryReduceVL(MachineInstr &MI);
+  bool isCandidate(const MachineInstr &MI) const;
+};
+
+} // end anonymous namespace
+
+// Definition of the pass identification member, followed by the pass
+// registration: DEBUG_TYPE is used as the pass argument, PASS_NAME as its
+// description, and MachineDominatorTreeWrapperPass is declared as a
+// dependency (matching the addRequired call in getAnalysisUsage).
+char RISCVVLOptimizer::ID = 0;
+INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
+
+/// Factory for the VL optimizer pass. Returns a newly allocated
+/// RISCVVLOptimizer as a FunctionPass pointer.
+/// NOTE(review): ownership presumably passes to the caller / pass manager --
+/// confirm at the call site.
+FunctionPass *llvm::createRISCVVLOptimizerPass() {
+  return new RISCVVLOptimizer();
+}
+
+/// Tell whether register R is a vector register.
+///
+/// A physical register qualifies when it is a member of RISCV::VRRegClass; a
+/// virtual register qualifies when the TSFlags of its register class (looked
+/// up through MRI) identify a vector register class.
+static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
+  if (!R.isPhysical()) {
+    // Virtual register: decide from the flags of its assigned class.
+    const TargetRegisterClass *RegClass = MRI->getRegClass(R);
+    return RISCVRI::isVRegClass(RegClass->TSFlags);
+  }
+  return RISCV::VRRegClass.contains(R);
+}
+
+/// Represents the EMUL and EEW of a MachineOperand.
+///
+/// An OperandInfo is either Known (EMUL and Log2EEW are meaningful) or
+/// Unknown (default constructed; EMUL is empty).
+struct OperandInfo {
+  enum class State {
+    Unknown,
+    Known,
+  } S;
+
+  // Represent as 1,2,4,8, ... and fractional indicator. This is because
+  // EMUL can take on values that don't map to RISCVII::VLMUL values exactly.
+  // For example, a mask operand can have an EMUL less than MF8.
+  // The pair is (magnitude, is-fractional): {2, false} is EMUL=2 and
+  // {2, true} is EMUL=1/2.
+  std::optional<std::pair<unsigned, bool>> EMUL;
+
+  // log2 of the effective element width. Zero-initialised so that an Unknown
+  // OperandInfo never carries an indeterminate value.
+  unsigned Log2EEW = 0;
+
+  /// Known operand whose EMUL is exactly representable as a RISCVII::VLMUL.
+  OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW)
+      : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {
+  }
+
+  /// Known operand with EMUL given directly as (magnitude, is-fractional).
+  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
+      : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {}
+
+  /// Unknown operand: EMUL is left empty.
+  OperandInfo() : S(State::Unknown) {}
+
+  bool isUnknown() const { return S == State::Unknown; }
+  bool isKnown() const { return S == State::Known; }
+
+  /// Return true if A and B have the same EEW and the same EMUL.
+  /// Both operands must be Known.
+  static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
+    assert(A.isKnown() && B.isKnown() && "Both operands must be known");
+
+    return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first &&
+           A.EMUL->second == B.EMUL->second;
+  }
+
+  /// Print "Unknown", or "EMUL: m<N>, EEW: <bits>" for an integral EMUL and
+  /// "EMUL: mf<N>, EEW: <bits>" for a fractional one.
+  void print(raw_ostream &OS) const {
+    if (isUnknown()) {
+      OS << "Unknown";
+      return;
+    }
+    assert(EMUL && "Expected EMUL to have value");
+    // The "m" prefix is always present; only the "f" depends on the
+    // fractional flag. Emitting "m" conditionally would print an integral
+    // EMUL of 2 as "f2" instead of "m2".
+    OS << "EMUL: m";
+    if (EMUL->second)
+      OS << "f";
+    OS << EMUL->first;
+    OS << ", EEW: " << (1 << Log2EEW);
+  }
+};
+
+/// Stream insertion for OperandInfo; forwards to OperandInfo::print and
+/// returns the stream so that insertions can be chained.
+static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
+  OI.print(OS);
+  return OS;
+}
+
+namespace llvm {
+namespace RISCVVType {
+/// Return the RISCVII::VLMUL that is two times VLMul.
+/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8.
+///
+/// Each case steps to the next larger LMUL: F8 -> F4 -> F2 -> 1 -> 2 -> 4 -> 8.
+static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) {
+  switch (VLMul) {
+  case RISCVII::VLMUL::LMUL_F8:
+    return RISCVII::VLMUL::LMUL_F4;
+  case RISCVII::VLMUL::LMUL_F4:
+    return RISCVII::VLMUL::LMUL_F2;
+  case RISCVII::VLMUL::LMUL_F2:
+    return RISCVII::VLMUL::LMUL_1;
+  case RISCVII::VLMUL::LMUL_1:
+    return RISCVII::VLMUL::LMUL_2;
+  case RISCVII::VLMUL::LMUL_2:
+    return RISCVII::VLMUL::LMUL_4;
+  case RISCVII::VLMUL::LMUL_4:
+    return RISCVII::VLMUL::LMUL_8;
+  // LMUL_8 has no doubled encoding; it and every other value violate the
+  // precondition above.
+  case RISCVII::VLMUL::LMUL_8:
+  default:
+    llvm_unreachable("Could not multiply VLMul by 2");
+  }
+}
+
+/// Compute EMUL = (EEW / SEW) * LMUL for an operand of MI.
+///
+/// EEW is 2^Log2EEW; LMUL is decoded from the TSFlags of MI and SEW is read
+/// from MI's SEW operand. The result is a (magnitude, is-fractional) pair:
+/// {N, false} means EMUL = N and {N, true} means EMUL = 1/N.
+static std::pair<unsigned, bool>
+getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
+  const MCInstrDesc &Desc = MI.getDesc();
+  auto [LMULValue, LMULIsFractional] =
+      RISCVVType::decodeVLMUL(RISCVII::getLMul(Desc.TSFlags));
+  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+
+  // Build EMUL as the fraction Numerator/Denominator, starting from EEW/SEW
+  // and folding LMUL into whichever side it belongs on: a fractional LMUL
+  // (1/LMULValue) scales the denominator, an integral one the numerator.
+  unsigned Numerator = 1 << Log2EEW;
+  unsigned Denominator = 1 << Log2SEW;
+  if (LMULIsFractional)
+    Denominator *= LMULValue;
+  else
+    Numerator *= LMULValue;
+
+  // Reduce to lowest terms so that values below 1 stay representable.
+  unsigned Divisor = std::gcd(Numerator, Denominator);
+  Numerator /= Divisor;
+  Denominator /= Divisor;
+
+  // Report the larger term as the magnitude; EMUL is fractional exactly when
+  // the denominator is the larger one.
+  unsigned Magnitude = Numerator > Denominator ? Numerator : Denominator;
+  return std::make_pair(Magnitude, Denominator > Numerator);
+}
+} // end namespace RISCVVType
+} // end namespace llvm
+
+/// OperandInfo for operand MO of an integer extension instruction MI.
+///
+/// The destination (operand 0) has EEW=SEW and EMUL=LMUL. Every other operand
+/// is treated as the source, with EEW=SEW/Factor (e.g. Factor 2 => SEW/2) and
+/// EMUL=(EEW/SEW)*LMUL. LMUL comes from the TSFlags of MI and SEW from MI's
+/// SEW operand.
+static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor,
+                                                  const MachineInstr &MI,
+                                                  const MachineOperand &MO) {
+  const MCInstrDesc &Desc = MI.getDesc();
+  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+
+  const bool IsDest = MO.getOperandNo() == 0;
+  if (IsDest)
+    return OperandInfo(RISCVII::getLMul(Desc.TSFlags), Log2SEW);
+
+  // Source operand: narrow the element width by Factor, then derive the
+  // matching EMUL from the narrowed width.
+  const unsigned SEW = 1 << Log2SEW;
+  const unsigned SrcLog2EEW = Log2_32(SEW / Factor);
+  auto SrcEMUL = RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(SrcLog2EEW, MI);
+  return OperandInfo(SrcEMUL, SrcLog2EEW);
+}
+
+/// Decide whether MO is a mask operand of MI.
+///
+/// MO counts as a mask operand when it is a vector register operand and the
+/// instruction description of MI constrains that operand position to the
+/// VMV0 register class.
+static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
+                          const MachineRegisterInfo *MRI) {
+  if (MO.isReg() && isVectorRegClass(MO.getReg(), MRI)) {
+    const auto &OpDesc = MI.getDesc().operands()[MO.getOperandNo()];
+    return OpDesc.RegClass == RISCV::VMV0RegClassID;
+  }
+  return false;
+}
+
+/// Return the OperandInfo for MO, which is an operand of MI.
+static OperandInfo getOperandInfo(const MachineInstr &MI,
+                                  const MachineOperand &MO,
+                                  const MachineRegisterInfo *MRI) {
+  const RISCVVPseudosTable::PseudoInfo *RVV =
+      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
+  assert(RVV && "Could not find MI in PseudoTable");
+
+  // MI has a VLMUL and SEW associated with it. The RVV specification defines
+  // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and
+  // MI.SEW.
+  RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
+  unsigned MILog2SEW =
+      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+
+  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc());
+
+  // We bail out early for instructions that have passthru with non NoRegister,
+  // which means they are using TU policy. We are not interested in these
+  // since they must preserve the entire register content.
+  if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() &&
+      (MO.getReg() != RISCV::NoRegister))
+    return {};
+
+  bool IsMODef = MO.getOperandNo() == 0;
+
+  // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL
+  if (isMaskOperand(MI, MO, MRI))
+    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);
+
+  // switch against BaseInstr to reduce number of cases that need to be
+  // considered.
+  switch (RVV->BaseInstr) {
+
+  // 6. Configuration-Setting Instructions
+  // Configuration setting instructions do not read or write vector registers
+  case RISCV::VSETIVLI:
+  case RISCV::VSETVL:
+  case RISCV::VSETVLI:
+    llvm_unreachable("Configuration setting instructions do not read or write "
+                     "vector registers");
+
+  // 11. Vector Integer Arithmetic Instructions
+  // 11.1. Vector Single-Width Integer Add and Subtract
+  case RISCV::VADD_VI:
+  case RISCV::VADD_VV:
+  case RISCV::VADD_VX:
+  case RISCV::VSUB_VV:
+  case RISCV::VSUB_VX:
+  case RISCV::VRSUB_VI:
+  case RISCV::VRSUB_VX:
+  // 11.5. Vector Bitwise Logical Instructions
+  // 11.6. Vector Single-Width Shift Instructions
+  // EEW=SEW. EMUL=LMUL.
+  case RISCV::VAND_VI:
+  case RISCV::VAND_VV:
+  case RISCV::VAND_VX:
+  case RISCV::VOR_VI:
+  case RISCV::VOR_VV:
+  case RISCV::VOR_VX:
+  case RISCV::VXOR_VI:
+  case RISCV::VXOR_VV:
+  case RISCV::VXOR_VX:
+  case RISCV::VSLL_VI:
+  case RISCV::VSLL_VV:
+  case RISCV::VSLL_VX:
+  case RISCV::VSRL_VI:
+  case RISCV::VSRL_VV:
+  case RISCV::VSRL_VX:
+  case RISCV::VSRA_VI:
+  case RISCV::VSRA_VV:
+  case RISCV::VSRA_VX:
+  // 11.9. Vector Integer Min/Max Instructions
+  // EEW=SEW. EMUL=LMUL.
+  case RISCV::VMINU_VV:
+  case RISCV::VMINU_VX:
+  case RISCV::VMIN_VV:
+  case RISCV::VMIN_VX:
+  case RISCV::VMAXU_VV:
+  case RISCV::VMAXU_VX:
+  case RISCV::VMAX_VV:
+  case RISCV::VMAX_VX:
+  // 11.10. Vector Single-Width Integer Multiply Instructions
+  // Source and Dest EEW=SEW and EMUL=LMUL.
+  case RISCV::VMUL_VV:
+  case RISCV::VMUL_VX:
+  case RISCV::VMULH_VV:
+  case RISCV::VMULH_VX:
+  case RISCV::VMULHU_VV:
+  case RISCV::VMULHU_VX:
+  case RISCV::VMULHSU_VV:
+  case RISCV::VMULHSU_VX:
+  // 11.11. Vector Integer Divide Instructions
+  // EEW=SEW. EMUL=LMUL.
+  case RISCV::VDIVU_VV:
+  case RISCV::VDIVU_VX:
+  case RISCV::VDIV_VV:
+  case RISCV::VDIV_VX:
+  case RISCV::VREMU_VV:
+  case RISCV::VREMU_VX:
+  case RISCV::VREM_VV:
+  case RISCV::VREM_VX:
+  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
+  // EEW=SEW. EMUL=LMUL.
+  case RISCV::VMACC_VV:
+  case RISCV::VMACC_VX:
+  case RISCV::VNMSAC_VV:
+  case RISCV::VNMSAC_VX:
+  case RISCV::VMADD_VV:
+  case RISCV::VMADD_VX:
+  case RISCV::VNMSUB_VV:
+  case RISCV::VNMSUB_VX:
+  // 11.15. Vector Integer Merge Instructions
+  // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
+  // (EEW/SEW)*LMUL. Mask operand is handled before this switch.
+  case RISCV::VMERGE_VIM:
+  case RISCV::VMERGE_VVM:
+  case RISCV::VMERGE_VXM:
+  // 11.16. Vector Integer Move Instructions
+  // 12. Vector Fixed-Point Arithmetic Instructions
+  // 12.1. Vector Single-Width Saturating Add and Subtract
+  // 12.2. Vector Single-Width Averaging Add and Subtract
+  // EEW=SEW. EMUL=LMUL.
+  case RISCV::VMV_V_I:
+  case RISCV::VMV_V_V:
+  case RISCV::VMV_V_X:
+  case RISCV::VSADDU_VI:
+  case RISCV::VSADDU_VV:
+  case RISCV::VSADDU_VX:
+  case RISCV::VSADD_VI:
+  case RISCV::VSADD_VV:
+  case RISCV::VSADD_VX:
+  case RISCV::VSSUBU_VV:
+  case RISCV::VSSUBU_VX:
+  case RISCV::VSSUB_VV:
+  case RISCV::VSSUB_VX:
+  case RISCV::VAADDU_VV:
+  case RISCV::VAADDU_VX:
+  case RISCV::VAADD_VV:
+  case RISCV::VAADD_VX:
+  case RISCV::VASUBU_VV:
+  case RISCV::VASUBU_VX:
+  case RISCV::VASUB_VV:
+  case RISCV::VASUB_VX:
+  // 12.4. Vector Single-Width Scaling Shift Instructions
+  // EEW=SEW. EMUL=LMUL.
+  case RISCV::VSSRL_VI:
+  case RISCV::VSSRL_VV:
+  case RISCV::VSSRL_VX:
+  case RISCV::VSSRA_VI:
+  case RISCV::VSSRA_VV:
+  case RISCV::VSSRA_VX:
+  // 16. Vector Permutation Instructions
+  // 16.1. Integer Scalar Move Instructions
+  // 16.2. Floating-Point Scalar Move Instructions
+  // EMUL=LMUL. EEW=SEW.
+  case RISCV::VMV_X_S:
+  case RISCV::VMV_S_X:
+  case RISCV::VFMV_F_S:
+  case RISCV::VFMV_S_F:
+  // 16.3. Vector Slide Instructions
+  // EMUL=LMUL. EEW=SEW.
+  case RISCV::VSLIDEUP_VI:
+  case RISCV::VSLIDEUP_VX:
+  case RISCV::VSLIDEDOWN_VI:
+  case RISCV::VSLIDEDOWN_VX:
+  case RISCV::VSLIDE1UP_VX:
+  case RISCV::VFSLIDE1UP_VF:
+  case RISCV::VSLIDE1DOWN_VX:
+  case RISCV::VFSLIDE1DOWN_VF:
+  // 16.4. Vector Register Gather Instructions
+  // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1.
+  case RISCV::VRGATHER_VI:
+  case RISCV::VRGATHER_VV:
+  case RISCV::VRGATHER_VX:
+  // 16.5. Vector Compress Instruction
+  // EMUL=LMUL. EEW=SEW.
+  case RISCV::VCOMPRESS_VM:
+    return OperandInfo(MIVLMul, MILog2SEW);
+
+  // 11.2. Vector Widening Integer Add/Subtract
+  // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL.
+  case RISCV::VWADDU_VV:
+  case RISCV::VWADDU_VX:
+  case RISCV::VWSUBU_VV:
+  case RISCV::VWSUBU_VX:
+  case RISCV::VWADD_VV:
+  case RISCV::VWADD_VX:
+  case RISCV::VWSUB_VV:
+  case RISCV::VWSUB_VX:
+  case RISCV::VWSLL_VI:
+  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
+  // Destination EEW=2*SEW and EMUL=2*EMUL. Source EEW=SEW and EMUL=LMUL.
+  case RISCV::VSMUL_VV:
+  case RISCV::VSMUL_VX:
----------------
lukel97 wrote:

Feel free to land first and add a test later, don't let my reviews block this!

https://github.com/llvm/llvm-project/pull/108640


More information about the llvm-commits mailing list