[llvm] [RISCV] Introduce VLOptimizer pass (PR #108640)
Pengcheng Wang via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 13 21:02:57 PDT 2024
================
@@ -0,0 +1,1569 @@
+//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This pass reduces the VL where possible at the MI level, before VSETVLI
+// instructions are inserted.
+//
+// The purpose of this optimization is to make the VL argument, for instructions
+// that have a VL argument, as small as possible. This is implemented by
+// visiting each instruction in reverse order and, if it has a VL argument,
+// checking whether the VL can be reduced.
+//
+//===---------------------------------------------------------------------===//
+
+#include "RISCV.h"
+#include "RISCVMachineFunctionInfo.h"
+#include "RISCVSubtarget.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/InitializePasses.h"
+
+#include <algorithm>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-vl-optimizer"
+
+namespace {
+
+/// Machine function pass that tries to reduce the VL operand of instructions
+/// that have one. It requires the machine dominator tree and declares that it
+/// preserves the CFG.
+class RISCVVLOptimizer : public MachineFunctionPass {
+  // Per-function state. Defaulted to null so that a use before assignment is
+  // a deterministic null dereference rather than a read of an indeterminate
+  // pointer.
+  // NOTE(review): presumably assigned at the top of runOnMachineFunction,
+  // whose body is not visible here -- confirm.
+  const MachineRegisterInfo *MRI = nullptr;
+  const MachineDominatorTree *MDT = nullptr;
+
+public:
+  static char ID;
+
+  RISCVVLOptimizer() : MachineFunctionPass(ID) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  /// Declare the analyses this pass needs (the machine dominator tree) and
+  /// what it preserves (the CFG), then defer to the base class.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<MachineDominatorTreeWrapperPass>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override { return "RISC-V VL Optimizer"; }
+
+private:
+  bool tryReduceVL(MachineInstr &MI);
+  bool isCandidate(const MachineInstr &MI) const;
+};
+
+} // end anonymous namespace
+
+// Definition of the static pass ID declared in RISCVVLOptimizer, followed by
+// the pass registration macros. MachineDominatorTreeWrapperPass is listed as a
+// dependency, matching the analysis required in getAnalysisUsage().
+char RISCVVLOptimizer::ID = 0;
+INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, "RISC-V VL Optimizer",
+ false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, "RISC-V VL Optimizer", false,
+ false)
+
+/// Factory for the pass: returns a newly heap-allocated RISCVVLOptimizer.
+/// NOTE(review): ownership presumably transfers to the caller (the pass
+/// manager) -- confirm against the call site.
+FunctionPass *llvm::createRISCVVLOptimizerPass() {
+ return new RISCVVLOptimizer();
+}
+
+/// Report whether R is a vector register. A physical register is tested for
+/// membership in VR; a virtual register qualifies when its register class is
+/// VR, VRM2, VRM4 or VRM8, or a subclass of one of them.
+static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
+  if (!R.isPhysical()) {
+    // Virtual register: look its class up and try each LMUL grouping in turn.
+    const TargetRegisterClass *VRegClass = MRI->getRegClass(R);
+    if (RISCV::VRRegClass.hasSubClassEq(VRegClass))
+      return true;
+    if (RISCV::VRM2RegClass.hasSubClassEq(VRegClass))
+      return true;
+    if (RISCV::VRM4RegClass.hasSubClassEq(VRegClass))
+      return true;
+    return RISCV::VRM8RegClass.hasSubClassEq(VRegClass);
+  }
+
+  return RISCV::VRRegClass.contains(R);
+}
+
+/// Represents the EMUL and EEW of a MachineOperand. An OperandInfo is either
+/// Known (EMUL and Log2EEW are meaningful) or Unknown (they must not be read).
+struct OperandInfo {
+  enum class State {
+    Unknown,
+    Known,
+  } S;
+
+  // Represent as 1,2,4,8, ... and fractional indicator. This is because
+  // EMUL can take on values that don't map to RISCVII::VLMUL values exactly.
+  // For example, a mask operand can have an EMUL less than MF8.
+  // The pair is (magnitude, is-fractional): {2, false} is m2, {2, true} is mf2.
+  std::pair<unsigned, bool> EMUL;
+
+  // log2 of the effective element width. Zero-initialized so that an Unknown
+  // OperandInfo never carries an indeterminate value.
+  unsigned Log2EEW = 0;
+
+  /// Construct a Known OperandInfo from an LMUL encoding and log2(EEW).
+  OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW)
+      : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {
+  }
+
+  /// Construct a Known OperandInfo from an already decoded
+  /// (magnitude, is-fractional) EMUL and log2(EEW).
+  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
+      : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {}
+
+  /// Construct an Unknown OperandInfo. Passing State::Known asserts.
+  OperandInfo(State S) : S(S) {
+    assert(S != State::Known &&
+           "This constructor may only be used to construct "
+           "an Unknown OperandInfo");
+  }
+
+  bool isUnknown() const { return S == State::Unknown; }
+  bool isKnown() const { return S == State::Known; }
+
+  /// Return true if A and B have the same EEW and the same EMUL (magnitude
+  /// and fractional flag). Both must be Known.
+  static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
+    assert(A.isKnown() && B.isKnown() && "Both operands must be known");
+    return A.Log2EEW == B.Log2EEW && A.EMUL.first == B.EMUL.first &&
+           A.EMUL.second == B.EMUL.second;
+  }
+
+  /// Print "Unknown", or "EMUL: m<N>, EEW: <W>" for an integral EMUL and
+  /// "EMUL: mf<N>, EEW: <W>" for a fractional one.
+  void print(raw_ostream &OS) const {
+    if (isUnknown()) {
+      OS << "Unknown";
+      return;
+    }
+    // The "m" prefix is unconditional; only the "f" marks a fractional EMUL.
+    // (Previously an integral EMUL printed as "f<N>", which read as a
+    // fraction.)
+    OS << "EMUL: m";
+    if (EMUL.second)
+      OS << "f";
+    OS << EMUL.first;
+    OS << ", EEW: " << (1 << Log2EEW);
+  }
+};
+
+/// Stream an OperandInfo to OS by delegating to OperandInfo::print. Returns OS
+/// so that insertions can be chained.
+static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
+ OI.print(OS);
+ return OS;
+}
+
+/// Double an LMUL encoding: returns the RISCVII::VLMUL equal to 2 * VLMul.
+/// Precondition: VLMul is neither LMUL_RESERVED nor LMUL_8; any other input
+/// reaches llvm_unreachable.
+static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) {
+  using LMul = RISCVII::VLMUL;
+  switch (VLMul) {
+  // Fractional encodings step up towards LMUL_1.
+  case LMul::LMUL_F8: return LMul::LMUL_F4;
+  case LMul::LMUL_F4: return LMul::LMUL_F2;
+  case LMul::LMUL_F2: return LMul::LMUL_1;
+  // Integral encodings step up towards LMUL_8.
+  case LMul::LMUL_1:  return LMul::LMUL_2;
+  case LMul::LMUL_2:  return LMul::LMUL_4;
+  case LMul::LMUL_4:  return LMul::LMUL_8;
+  // LMUL_8 has no doubled encoding.
+  case LMul::LMUL_8:
+  default:
+    llvm_unreachable("Could not multiply VLMul by 2");
+  }
+}
+
+/// Halve an LMUL encoding: returns the RISCVII::VLMUL equal to VLMul / 2.
+/// Precondition: VLMul is neither LMUL_RESERVED nor LMUL_F8; any other input
+/// reaches llvm_unreachable.
+static RISCVII::VLMUL halfVLMUL(RISCVII::VLMUL VLMul) {
+  using LMul = RISCVII::VLMUL;
+  switch (VLMul) {
+  // Integral encodings step down towards LMUL_1, then into the fractions.
+  case LMul::LMUL_8:  return LMul::LMUL_4;
+  case LMul::LMUL_4:  return LMul::LMUL_2;
+  case LMul::LMUL_2:  return LMul::LMUL_1;
+  case LMul::LMUL_1:  return LMul::LMUL_F2;
+  case LMul::LMUL_F2: return LMul::LMUL_F4;
+  case LMul::LMUL_F4: return LMul::LMUL_F8;
+  // LMUL_F8 has no halved encoding.
+  case LMul::LMUL_F8:
+  default:
+    llvm_unreachable("Could not divide VLMul by 2");
+  }
+}
+
+/// Compute EMUL = (EEW / SEW) * LMUL as a reduced fraction.
+/// EEW is 2^Log2EEW, LMUL is decoded from the TSFlags of MI, and SEW is
+/// 2^(immediate of MI's SEW operand).
+/// \returns a (magnitude, is-fractional) pair: {N, false} means EMUL = N and
+/// {N, true} means EMUL = 1/N.
+static std::pair<unsigned, bool>
+getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
+  const MCInstrDesc &Desc = MI.getDesc();
+  const auto [LMULValue, LMULIsFractional] =
+      RISCVVType::decodeVLMUL(RISCVII::getLMul(Desc.TSFlags));
+  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+
+  // Build the unreduced fraction (EEW * LMUL) / SEW. A fractional LMUL of
+  // 1/LMULValue scales the denominator instead of the numerator.
+  unsigned Num = 1 << Log2EEW;
+  unsigned Denom = 1 << Log2SEW;
+  if (LMULIsFractional)
+    Denom *= LMULValue;
+  else
+    Num *= LMULValue;
+
+  // Reduce to simplest form so that one side becomes 1.
+  const unsigned GCD = std::gcd(Num, Denom);
+  Num /= GCD;
+  Denom /= GCD;
+
+  // Report the larger side as the magnitude; the value is fractional exactly
+  // when the denominator is the larger side.
+  const bool IsFractional = Denom > Num;
+  return std::make_pair(IsFractional ? Denom : Num, IsFractional);
+}
+
+/// Return true if MO sits at position OpN of its parent instruction, where
+/// OpN is counted as if there were no passthru operand: when the parent's
+/// first def is tied to its first use, the position compared against is
+/// OpN + 1.
+static bool isOpN(const MachineOperand &MO, unsigned OpN) {
+  const MCInstrDesc &Desc = MO.getParent()->getDesc();
+  const unsigned PassthruShift = RISCVII::isFirstDefTiedToFirstUse(Desc) ? 1 : 0;
+  return MO.getOperandNo() == OpN + PassthruShift;
+}
+
+/// OperandInfo for a vector operand of an indexed segment load or store, i.e.
+/// an instruction of the form v.*seg<nf>ei<eew>.v.
+///  - Data operand (operand 0): EEW=SEW, EMUL=LMUL.
+///  - Index operand (operand 2, passthru-adjusted): EEW=<eew>,
+///    EMUL=(EEW/SEW)*LMUL.
+/// LMUL comes from the TSFlags of MI and SEW from MI's SEW operand. Log2EEW is
+/// log2 of the index <eew>. IsLoad is currently unused. Any other operand
+/// reaches llvm_unreachable.
+static OperandInfo getIndexSegmentLoadStoreOperandInfo(unsigned Log2EEW,
+                                                       const MachineInstr &MI,
+                                                       const MachineOperand &MO,
+                                                       bool IsLoad) {
+  const bool IsDataOperand = MO.getOperandNo() == 0;
+  if (!IsDataOperand) {
+    // The only other vector operand is the index register group.
+    if (!isOpN(MO, 2))
+      llvm_unreachable("Could not get OperandInfo for non-vector register of an "
+                       "indexed segment load or store instruction");
+    return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), Log2EEW);
+  }
+
+  // Data vector register group: take LMUL and SEW straight from MI.
+  const MCInstrDesc &Desc = MI.getDesc();
+  const RISCVII::VLMUL DataLMUL = RISCVII::getLMul(Desc.TSFlags);
+  const unsigned DataLog2SEW =
+      MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+  return OperandInfo(DataLMUL, DataLog2SEW);
+}
+
+/// OperandInfo for an operand of an integer extension instruction.
+///  - Operand 0 (dest): EEW=SEW, EMUL=LMUL.
+///  - Any other operand (source): EEW=SEW/Factor (e.g. Factor 2 => SEW/2),
+///    EMUL=(EEW/SEW)*LMUL.
+/// LMUL comes from the TSFlags of MI and SEW from MI's SEW operand.
+static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor,
+                                                  const MachineInstr &MI,
+                                                  const MachineOperand &MO) {
+  const MCInstrDesc &Desc = MI.getDesc();
+  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
+
+  const bool IsDest = MO.getOperandNo() == 0;
+  if (IsDest)
+    return OperandInfo(RISCVII::getLMul(Desc.TSFlags), Log2SEW);
+
+  // Source operand: narrow SEW by Factor, then derive EMUL from that width.
+  const unsigned SEW = 1 << Log2SEW;
+  const unsigned SourceEEW = SEW / Factor;
+  const unsigned SourceLog2EEW = Log2_32(SourceEEW);
+  return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(SourceLog2EEW, MI),
+                     SourceLog2EEW);
+}
+
+/// Return true if MO is a mask operand of MI: a vector register operand whose
+/// slot in MI's instruction description is constrained to the VMV0 register
+/// class.
+static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
+                          const MachineRegisterInfo *MRI) {
+  // Only vector register operands can be masks.
+  if (!MO.isReg())
+    return false;
+  if (!isVectorRegClass(MO.getReg(), MRI))
+    return false;
+
+  const unsigned OpIdx = MO.getOperandNo();
+  return MI.getDesc().operands()[OpIdx].RegClass == RISCV::VMV0RegClassID;
+}
+
+/// Return the OperandInfo for MO, which is an operand of MI.
+static OperandInfo getOperandInfo(const MachineInstr &MI,
----------------
wangpc-pp wrote:
Maybe we can reuse some parts of https://github.com/llvm/llvm-project/pull/105945?
https://github.com/llvm/llvm-project/pull/108640
More information about the llvm-commits
mailing list