[llvm] [Exegesis][RISCV] Add initial RVV support (PR #128767)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 27 10:07:07 PST 2025
================
@@ -24,10 +32,574 @@
namespace llvm {
namespace exegesis {
+/// When enabled, the VL (AVL) operand of RVV pseudos is only populated with
+/// VLMAX instead of the full set of candidate AVL values.
+static cl::opt<bool> OnlyUsesVLMAXForVL(
+    "riscv-vlmax-for-vl", cl::Hidden, cl::init(false),
+    cl::desc("Only enumerate VLMAX for VL operand"));
+
+/// When enabled, rounding-mode operands (FRM / VXRM) are enumerated rather
+/// than pinned to a single default mode.
+static cl::opt<bool> EnumerateRoundingModes(
+    "riscv-enumerate-rounding-modes", cl::Hidden, cl::init(true),
+    cl::desc("Enumerate different FRM and VXRM"));
+
+/// Optional regex; only configurations whose description matches it are
+/// shown.
+static cl::opt<std::string> FilterConfig(
+    "riscv-filter-config", cl::Hidden, cl::init(""),
+    cl::desc("Show only the configs matching this regex"));
+
#include "RISCVGenExegesis.inc"
namespace {
+template <class BaseT> class RISCVSnippetGenerator : public BaseT {
+  /// Print the textual name of rounding mode \p Val to \p OS. \p Val is
+  /// interpreted as a fixed-point (VXRM) mode when \p UsesVXRM is true and as
+  /// a floating-point (FRM) mode otherwise.
+  static void printRoundingMode(raw_ostream &OS, unsigned Val, bool UsesVXRM) {
+    if (UsesVXRM) {
+      assert(RISCVVXRndMode::isValidRoundingMode(Val));
+      OS << RISCVVXRndMode::roundingModeToString(
+          static_cast<RISCVVXRndMode::RoundingMode>(Val));
+    } else {
+      assert(RISCVFPRndMode::isValidRoundingMode(Val));
+      OS << RISCVFPRndMode::roundingModeToString(
+          static_cast<RISCVFPRndMode::RoundingMode>(Val));
+    }
+  }
+
+  /// Smallest SEW that is ever enumerated.
+  static constexpr unsigned MinSEW = 8;
+  // ELEN is basically SEW_max.
+  unsigned ELEN = 64;
+
+  // We can't know the real min/max VLEN w/o a Function, so we're
+  // using the VLen from Zvl.
+  unsigned ZvlVLen = 32;
+
+  /// Mask for registers that are NOT standalone registers like X0 and V0.
+  BitVector AggregateRegisters;
+
+  /// Returns true when \p Opcode requires any of the feature bits in \p FBs.
+  static bool
+  isOpcodeAvailableIn(unsigned Opcode,
+                      ArrayRef<RISCV_MC::SubtargetFeatureBits> FBs) {
+    const FeatureBitset RequiredFeatures =
+        RISCV_MC::computeRequiredFeatures(Opcode);
+    // Keep the generated enum type rather than narrowing to a fixed-width
+    // integer: its underlying type comes from TableGen and may be wider than
+    // 8 bits, in which case a uint8_t loop variable would silently truncate.
+    return llvm::any_of(FBs, [&](RISCV_MC::SubtargetFeatureBits FB) {
+      return RequiredFeatures[FB];
+    });
+  }
+
+  /// Returns true if \p Opcode requires the HasVInstructionsAnyF predicate,
+  /// which this generator uses to identify floating-point vector operations.
+  static bool isRVVFloatingPointOp(unsigned Opcode) {
+    return isOpcodeAvailableIn(Opcode,
+                               {RISCV_MC::Feature_HasVInstructionsAnyFBit});
+  }
+
+  // Get the element group width (EGW) of each vector crypto extension.
+  static unsigned getZvkEGWSize(unsigned Opcode, unsigned SEW) {
+    using namespace RISCV_MC;
+    if (isOpcodeAvailableIn(Opcode, {Feature_HasStdExtZvkgBit,
+                                     Feature_HasStdExtZvknedBit,
+                                     Feature_HasStdExtZvksedBit}))
+      return 128U;
+    if (isOpcodeAvailableIn(Opcode, {Feature_HasStdExtZvkshBit}))
+      return 256U;
+    if (isOpcodeAvailableIn(Opcode, {Feature_HasStdExtZvknhaOrZvknhbBit}))
+      // In Zvknh[ab], when SEW=64 is used (i.e. Zvknhb), EGW is 256.
+      // Otherwise it's 128.
+      return SEW == 64 ? 256U : 128U;
+
+    llvm_unreachable("Unsupported opcode");
+  }
+
+  // A handy utility to multiply or divide an integer by LMUL.
+  template <typename T> static T multiplyLMul(T Val, RISCVVType::VLMUL VLMul) {
+    auto [LMul, IsFractional] = RISCVVType::decodeVLMUL(VLMul);
+    return IsFractional ? Val / LMul : Val * LMul;
+  }
+
+  /// Return the denominator of the fractional (i.e. the `x` in .vfx suffix) or
+  /// nullopt if BaseOpcode is not a vector sext/zext.
+  static std::optional<unsigned> isRVVSignZeroExtend(unsigned BaseOpcode) {
+    switch (BaseOpcode) {
+    case RISCV::VSEXT_VF2:
+    case RISCV::VZEXT_VF2:
+      return 2;
+    case RISCV::VSEXT_VF4:
+    case RISCV::VZEXT_VF4:
+      return 4;
+    case RISCV::VSEXT_VF8:
+    case RISCV::VZEXT_VF8:
+      return 8;
+    default:
+      return std::nullopt;
+    }
+  }
+
+  void annotateWithVType(const CodeTemplate &CT, const Instruction &Instr,
+                         unsigned BaseOpcode,
+                         const BitVector &ForbiddenRegisters,
+                         std::vector<CodeTemplate> &Result) const;
+
+public:
+  RISCVSnippetGenerator(const LLVMState &State,
+                        const SnippetGenerator::Options &Opts)
+      : BaseT(State, Opts),
+        AggregateRegisters(State.getRegInfo().getNumRegs(), /*initVal=*/true) {
+    // Initialize standalone registers mask: every register starts out marked
+    // as aggregate, then members of these classes are cleared.
+    const MCRegisterInfo &RegInfo = State.getRegInfo();
+    const unsigned StandaloneRegClasses[] = {
+        RISCV::GPRRegClassID, RISCV::FPR16RegClassID, RISCV::VRRegClassID};
+
+    for (unsigned RegClassID : StandaloneRegClasses)
+      for (unsigned Reg : RegInfo.getRegClass(RegClassID))
+        AggregateRegisters.reset(Reg);
+
+    // Initialize ELEN and VLEN.
+    // FIXME: We could have obtained these two constants from RISCVSubtarget
+    // but in order to get that from TargetMachine, we need a Function.
+    const MCSubtargetInfo &STI = State.getSubtargetInfo();
+    ELEN = STI.checkFeatures("+zve64x") ? 64 : 32;
+
+    // Probe +zvl32b ... +zvl65536b and keep the largest one that is enabled.
+    for (unsigned Size = 32U; Size <= 65536U; Size <<= 1) {
+      std::string ZvlQuery = "+zvl";
+      raw_string_ostream SS(ZvlQuery);
+      SS << Size << "b";
+      if (STI.checkFeatures(SS.str()) && ZvlVLen < Size)
+        ZvlVLen = Size;
+    }
+  }
+
+  Expected<std::vector<CodeTemplate>>
+  generateCodeTemplates(InstructionTemplate Variant,
+                        const BitVector &ForbiddenRegisters) const override;
+};
+
+/// Returns true if MaskedOp has an entry in the masked-pseudo table whose
+/// unmasked counterpart is exactly UnmaskedOp.
+static bool isMaskedSibling(unsigned MaskedOp, unsigned UnmaskedOp) {
+  if (const auto *MaskedInfo = RISCV::getMaskedPseudoInfo(MaskedOp))
+    return MaskedInfo->UnmaskedPseudo == UnmaskedOp;
+  return false;
+}
+
+// Returns true if BaseOpcode (described by I) must not be used in a serial
+// snippet. There are primarily two reasons an opcode is rejected:
+// (1) One of its use operands must not overlap the def operand
+//     (e.g. the operand is early-clobber).
+// (2) Its only use operand lives in a different register file than the def
+//     operand, e.g. the source is a vector while the result is a scalar.
+static bool isIneligibleOfSerialSnippets(unsigned BaseOpcode,
+                                         const Instruction &I) {
+  switch (BaseOpcode) {
+  case RISCV::VCOMPRESS_VM:
+  case RISCV::VCPOP_M:
+  case RISCV::VCPOP_V:
+  // The permutation instructions listed below cannot have a destination
+  // that overlaps with the source.
+  case RISCV::VRGATHEREI16_VV:
+  case RISCV::VRGATHER_VI:
+  case RISCV::VRGATHER_VV:
+  case RISCV::VRGATHER_VX:
+  case RISCV::VSLIDE1UP_VX:
+  case RISCV::VSLIDEUP_VI:
+  case RISCV::VSLIDEUP_VX:
+    return true;
+  default:
+    break;
+  }
+
+  // Any early-clobber operand forbids overlap between a def and a use.
+  return llvm::any_of(I.Operands,
+                      [](const Operand &Op) { return Op.isEarlyClobber(); });
+}
+
+/// Returns true if BaseOpcode is one of the four FP<->FP widening/narrowing
+/// conversions (VFNCVT_F_F_W, VFWCVT_F_F_V and their BF16 variants) that this
+/// file groups under Zvfhmin/Zvfbfmin.
+static bool isZvfhminZvfbfminOpcodes(unsigned BaseOpcode) {
+  return BaseOpcode == RISCV::VFNCVT_F_F_W ||
+         BaseOpcode == RISCV::VFWCVT_F_F_V ||
+         BaseOpcode == RISCV::VFNCVTBF16_F_F_W ||
+         BaseOpcode == RISCV::VFWCVTBF16_F_F_V;
+}
+
+/// Returns true if BaseOpcode is a vector reduction of the form
+/// `vredop vd, vs2, vs1`, where only element 0 of vs1 is read.
+/// Both the single-width and the widening variants (integer and
+/// floating-point) share this operand shape, so all of them are listed.
+static bool isVectorReduction(unsigned BaseOpcode) {
+  switch (BaseOpcode) {
+  // Single-width integer reductions.
+  case RISCV::VREDAND_VS:
+  case RISCV::VREDMAXU_VS:
+  case RISCV::VREDMAX_VS:
+  case RISCV::VREDMINU_VS:
+  case RISCV::VREDMIN_VS:
+  case RISCV::VREDOR_VS:
+  case RISCV::VREDSUM_VS:
+  case RISCV::VREDXOR_VS:
+  // Widening integer reductions.
+  case RISCV::VWREDSUMU_VS:
+  case RISCV::VWREDSUM_VS:
+  // Single-width floating-point reductions.
+  case RISCV::VFREDMAX_VS:
+  case RISCV::VFREDMIN_VS:
+  case RISCV::VFREDOSUM_VS:
+  case RISCV::VFREDUSUM_VS:
+  // Widening floating-point reductions (previously missing, unlike their
+  // integer counterparts above).
+  case RISCV::VFWREDOSUM_VS:
+  case RISCV::VFWREDUSUM_VS:
+    return true;
+  default:
+    return false;
+  }
+}
+
+template <class BaseT>
+void RISCVSnippetGenerator<BaseT>::annotateWithVType(
+ const CodeTemplate &OrigCT, const Instruction &Instr, unsigned BaseOpcode,
+ const BitVector &ForbiddenRegisters,
+ std::vector<CodeTemplate> &Result) const {
+ const MCSubtargetInfo &STI = SnippetGenerator::State.getSubtargetInfo();
+ unsigned VPseudoOpcode = Instr.getOpcode();
+
+ bool IsSerial = std::is_same_v<BaseT, SerialSnippetGenerator>;
+
+ const MCInstrDesc &MIDesc = Instr.Description;
+ const uint64_t TSFlags = MIDesc.TSFlags;
+
+ RISCVVType::VLMUL VLMul = RISCVII::getLMul(TSFlags);
+
+ const size_t StartingResultSize = Result.size();
+
+ SmallPtrSet<const Operand *, 4> VTypeOperands;
+ std::optional<AliasingConfigurations> SelfAliasing;
+ // Exegesis see instructions with tied operands being inherently serial.
+ // But for RVV instructions, those tied operands are passthru rather
+ // than real read operands. So we manually put dependency between
+ // destination (i.e. def) and any of the non-tied/SEW/policy/AVL/RM
+ // operands.
+ auto assignSerialRVVOperands = [&, this](InstructionTemplate &IT) {
+ // Initialize SelfAliasing on first use.
+ if (!SelfAliasing.has_value()) {
+ BitVector ExcludeRegs = ForbiddenRegisters;
+ ExcludeRegs |= AggregateRegisters;
+ SelfAliasing = AliasingConfigurations(Instr, Instr, ExcludeRegs);
+ bool EmptyUses = false;
+ for (auto &ARO : SelfAliasing->Configurations) {
+ auto &Uses = ARO.Uses;
+ for (auto ROA = Uses.begin(); ROA != Uses.end();) {
+ const Operand *Op = ROA->Op;
+ // Exclude tied operand(s).
+ if (Op->isTied()) {
+ ROA = Uses.erase(ROA);
+ continue;
+ }
+
+ // Special handling for reduction operations: for a given reduction
+ // `vredop vd, vs2, vs1`, we don't want vd to be aliased with vs1
+ // since we're only reading `vs1[0]` and many implementations
+ // optimize for this case (e.g. chaining). Instead, we're forcing
+ // it to create alias between vd and vs2.
+ if (isVectorReduction(BaseOpcode) &&
+ // vs1's operand index is always 3.
+ Op->getIndex() == 3) {
+ ROA = Uses.erase(ROA);
+ continue;
+ }
+
+ // Exclude any special operands like SEW and VL -- we've already
+ // assigned values to them.
+ if (VTypeOperands.count(Op)) {
+ ROA = Uses.erase(ROA);
+ continue;
+ }
+ ++ROA;
+ }
+
+ // If any of the use operand candidate lists is empty, there is
+ // no point to assign self aliasing registers.
+ if (Uses.empty()) {
+ EmptyUses = true;
+ break;
+ }
+ }
+ if (EmptyUses)
+ SelfAliasing->Configurations.clear();
+ }
+
+ // This is a self aliasing instruction so defs and uses are from the same
+ // instance, hence twice IT in the following call.
+ if (!SelfAliasing->empty() && !SelfAliasing->hasImplicitAliasing())
+ setRandomAliasing(*SelfAliasing, IT, IT);
+ };
+
+ // We are going to create a CodeTemplate (configuration) for each supported
+ // SEW, policy, and VL.
+ // FIXME: Account for EEW and EMUL.
+ SmallVector<std::optional<unsigned>, 4> Log2SEWs;
+ SmallVector<std::optional<unsigned>, 4> Policies;
+ SmallVector<std::optional<int>, 3> AVLs;
+ SmallVector<std::optional<unsigned>, 8> RoundingModes;
+
+ bool HasSEWOp = RISCVII::hasSEWOp(TSFlags);
+ bool HasPolicyOp = RISCVII::hasVecPolicyOp(TSFlags);
+ bool HasVLOp = RISCVII::hasVLOp(TSFlags);
+ bool HasRMOp = RISCVII::hasRoundModeOp(TSFlags);
+ bool UsesVXRM = RISCVII::usesVXRM(TSFlags);
+
+ if (HasSEWOp) {
+ const Operand &SEWOp = Instr.Operands[RISCVII::getSEWOpNum(MIDesc)];
+ VTypeOperands.insert(&SEWOp);
+
+ if (SEWOp.Info->OperandType == RISCVOp::OPERAND_SEW_MASK) {
+ // If it's a mask-producing instruction, the SEW operand is always zero.
+ Log2SEWs.push_back(0);
+ } else {
+ SmallVector<unsigned, 4> SEWCandidates;
+
+ // (RVV spec 3.4.2) For fractional LMUL, the supported SEW are between
+ // [SEW_min, LMUL * ELEN].
+ unsigned SEWUpperBound =
+ VLMul >= RISCVVType::LMUL_F8 ? multiplyLMul(ELEN, VLMul) : ELEN;
+ for (unsigned SEW = MinSEW; SEW <= SEWUpperBound; SEW <<= 1) {
+ SEWCandidates.push_back(SEW);
+
+ // Some scheduling classes already integrate SEW; only put
+ // their corresponding SEW values at the SEW operands.
+ // NOTE: It is imperative to put this condition in the front, otherwise
+ // it is tricky and difficult to know if there is an integrated
+ // SEW after other rules are applied to filter the candidates.
+ const auto *RVVBase =
+ RISCVVInversePseudosTable::getBaseInfo(BaseOpcode, VLMul, SEW);
+ if (RVVBase && (RVVBase->Pseudo == VPseudoOpcode ||
+ isMaskedSibling(VPseudoOpcode, RVVBase->Pseudo) ||
+ isMaskedSibling(RVVBase->Pseudo, VPseudoOpcode))) {
+ // There is an integrated SEW, remove all but the SEW pushed last.
+ SEWCandidates.erase(SEWCandidates.begin(), SEWCandidates.end() - 1);
+ break;
+ }
+ }
+
+ // Filter out some candidates.
+ for (auto SEW = SEWCandidates.begin(); SEW != SEWCandidates.end();) {
+ // For floating point operations, only select SEW of the supported FLEN.
+ if (isRVVFloatingPointOp(VPseudoOpcode)) {
+ bool Supported = false;
+ Supported |= isZvfhminZvfbfminOpcodes(BaseOpcode) && *SEW == 16;
+ Supported |= STI.hasFeature(RISCV::FeatureStdExtZvfh) && *SEW == 16;
+ Supported |= STI.hasFeature(RISCV::FeatureStdExtF) && *SEW == 32;
+ Supported |= STI.hasFeature(RISCV::FeatureStdExtD) && *SEW == 64;
+ if (!Supported) {
+ SEW = SEWCandidates.erase(SEW);
+ continue;
+ }
+ }
+
+ // The EEW for source operand in VSEXT and VZEXT is a fraction
+ // of the SEW, hence only SEWs that will lead to valid EEW are allowed.
+ if (auto Frac = isRVVSignZeroExtend(BaseOpcode))
+ if (*SEW / *Frac < MinSEW) {
+ SEW = SEWCandidates.erase(SEW);
+ continue;
+ }
+
+ // Most vector crypto 1.0 instructions only work on SEW=32.
+ using namespace RISCV_MC;
+ if (isOpcodeAvailableIn(BaseOpcode, {Feature_HasStdExtZvkgBit,
+ Feature_HasStdExtZvknedBit,
+ Feature_HasStdExtZvknhaOrZvknhbBit,
+ Feature_HasStdExtZvksedBit,
+ Feature_HasStdExtZvkshBit})) {
+ if (*SEW != 32)
+ // Zvknhb supports SEW=64 as well.
+ if (*SEW != 64 || !STI.hasFeature(RISCV::FeatureStdExtZvknhb) ||
+ !isOpcodeAvailableIn(BaseOpcode,
+ {Feature_HasStdExtZvknhaOrZvknhbBit})) {
+ SEW = SEWCandidates.erase(SEW);
+ continue;
+ }
+
+ // We're also enforcing the requirement of `LMUL * VLEN >= EGW` here,
+ // because some of the extensions have SEW-dependant EGW.
+ unsigned EGW = getZvkEGWSize(BaseOpcode, *SEW);
+ if (multiplyLMul(ZvlVLen, VLMul) < EGW) {
+ SEW = SEWCandidates.erase(SEW);
+ continue;
+ }
+ }
+
+ ++SEW;
+ }
+
+ // We're not going to produce any result with zero SEW candidate.
+ if (SEWCandidates.empty())
+ return;
+
+ for (unsigned SEW : SEWCandidates)
+ Log2SEWs.push_back(Log2_32(SEW));
+ }
+ } else {
+ Log2SEWs.push_back(std::nullopt);
+ }
+
+ if (HasPolicyOp) {
+ VTypeOperands.insert(&Instr.Operands[RISCVII::getVecPolicyOpNum(MIDesc)]);
+
+ Policies = {0, RISCVVType::TAIL_AGNOSTIC, RISCVVType::MASK_AGNOSTIC,
+ (RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC)};
+ } else {
+ Policies.push_back(std::nullopt);
+ }
+
+ if (HasVLOp) {
+ VTypeOperands.insert(&Instr.Operands[RISCVII::getVLOpNum(MIDesc)]);
+
+ if (OnlyUsesVLMAXForVL)
+ AVLs.push_back(-1);
+ else
+ AVLs = {// 5-bit immediate value
+ 1,
+ // VLMAX
+ -1,
+ // Non-X0 register
+ 0};
+ } else {
+ AVLs.push_back(std::nullopt);
+ }
+
+ if (HasRMOp) {
+ VTypeOperands.insert(&Instr.Operands[RISCVII::getVLOpNum(MIDesc) - 1]);
+
+ // If we're not enumerating all rounding modes,
+ // use zero (rne in FRM and rnu in VXRM) as the default
+ // mode.
+ RoundingModes = {0U};
----------------
topperc wrote:
Use the named constants from the enums? You'll have to pull the `UsesVXRM` check out and duplicate some code for each.
https://github.com/llvm/llvm-project/pull/128767
More information about the llvm-commits
mailing list