[llvm] [AMDGPU] Introduce iglp_opt(2): Generalized exp/mfma interleaving for select kernels (PR #81342)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 22 13:18:37 PST 2024
================
@@ -902,6 +906,923 @@ void MFMASmallGemmOpt::applyIGLPStrategy(
SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
}
+
+ return true;
+}
+
+class MFMAExpInterleaveOpt final : public IGLPStrategy {
+private:
+ // The count of TRANS SUs involved in the interleaved pipeline
+ static unsigned TransPipeCount;
+ // The count of MFMA SUs involved in the interleaved pipeline
+ static unsigned MFMAPipeCount;
+ // The number of transitive MFMA successors for each TRANS SU
+ static unsigned MFMAEnablement;
+ // The number of transitive TRANS predecessors for each MFMA SU
+ static unsigned ExpRequirement;
+ // The count of independent "chains" of MFMA instructions in the pipeline
+ static unsigned MFMAChains;
+ // The length of each independent "chain" of MFMA instructions
+ static unsigned MFMAChainLength;
+ // Whether or not the pipeline has V_CVT instructions
+ static bool HasCvt;
+ // Whether or not there are instructions between the TRANS instruction and
+ // V_CVT
+ static bool HasChainBetweenCvt;
+ // The first occurring DS_READ which feeds an MFMA chain
+ static std::optional<unsigned> FirstPipeDSR;
+ SmallVector<SUnit *, 4> MFMAChainSeeds;
+ // Compute the heuristics for the pipeline, returning whether or not the DAG
+ // is well formatted for the mutation
+ bool analyzeDAG(const SIInstrInfo *TII);
+
+ /// Whether or not the instruction is a transitive predecessor of an MFMA
+ /// instruction
+ class IsPipeExp final : public InstructionRule {
+ public:
+   /// Tests \p SU against every MFMA/WMMA SUnit of the DAG via
+   /// DAG->IsReachable(MFMA, SU) and returns true on the first hit.
+   ///
+   /// The list of MFMA/WMMA SUnits is gathered lazily into Cache on the first
+   /// call (walking DAG->SUnits back to front) and reused afterwards.
+   ///
+   /// \param SU         the candidate instruction.
+   /// \param Collection unused by this rule.
+   /// \param SyncPipe   the pipeline of SchedGroups; only SyncPipe[0].DAG is
+   ///                   read.
+   /// \returns false when the pipeline is empty or the DAG has no MFMA/WMMA.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     // Same guard the sibling rules use before touching SyncPipe[0].
+     if (!SyncPipe.size())
+       return false;
+
+     auto DAG = SyncPipe[0].DAG;
+
+     if (Cache->empty()) {
+       for (auto I = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); I != E; ++I)
+         if (TII->isMFMAorWMMA(*I->getInstr()))
+           Cache->push_back(&*I);
+
+       // Nothing to be a predecessor of.
+       if (Cache->empty())
+         return false;
+     }
+
+     return std::any_of(
+         Cache->begin(), Cache->end(), [&SU, &DAG](SUnit *TargetSU) {
+           return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
+         });
+   }
+
+   IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache) {}
+ };
+
+ /// Whether or not the instruction is a transitive predecessor of the
+ /// \p Number th MFMA of the MFMAs occurring after a TRANS instruction
+ class EnablesNthMFMA final : public InstructionRule {
+ private:
+   // 1-based index of the targeted MFMA/WMMA, counted among the MFMA/WMMA
+   // SUnits that follow the first TRANS SUnit in DAG->SUnits order.
+   unsigned Number = 1;
+
+ public:
+   /// Locates the Number-th MFMA/WMMA that appears after the first TRANS
+   /// SUnit (cached after the first successful lookup) and returns
+   /// DAG->IsReachable(thatMFMA, SU).
+   ///
+   /// \param SU         the candidate instruction.
+   /// \param Collection unused by this rule.
+   /// \param SyncPipe   the pipeline of SchedGroups; only SyncPipe[0].DAG is
+   ///                   read.
+   /// \returns false when the pipeline is empty or no such MFMA exists.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     // Same guard the sibling rules use before touching SyncPipe[0].
+     if (!SyncPipe.size())
+       return false;
+
+     auto DAG = SyncPipe[0].DAG;
+
+     if (Cache->empty()) {
+       // Scan state is only needed while filling the cache.
+       bool FoundTrans = false;
+       unsigned Counter = 1;
+
+       for (SUnit &Candidate : DAG->SUnits) {
+         if (FoundTrans && TII->isMFMAorWMMA(*Candidate.getInstr())) {
+           if (Counter == Number) {
+             Cache->push_back(&Candidate);
+             break;
+           }
+           ++Counter;
+         }
+         // MFMAs are only counted once the first TRANS has been seen.
+         if (!FoundTrans && TII->isTRANS(Candidate.getInstr()->getOpcode()))
+           FoundTrans = true;
+       }
+
+       if (Cache->empty())
+         return false;
+     }
+
+     return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
+   }
+
+   EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
+                  bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
+ };
+
+ /// Whether or not the instruction enables the exact MFMA that is the \p
+ /// Number th MFMA in the chain starting with \p ChainSeed
+ class EnablesNthMFMAInChain final : public InstructionRule {
+ private:
+   // How many MFMA-to-MFMA successor hops to take from ChainSeed.
+   unsigned Number = 1;
+   // First MFMA/WMMA of the chain being walked.
+   SUnit *ChainSeed;
+
+ public:
+   /// Walks \p Number hops from ChainSeed, each hop following the first
+   /// successor edge that leads to an MFMA/WMMA, caches the SUnit reached and
+   /// returns DAG->IsReachable(thatMFMA, SU).
+   ///
+   /// \param SU         the candidate instruction.
+   /// \param Collection unused by this rule.
+   /// \param SyncPipe   the pipeline of SchedGroups; only SyncPipe[0].DAG is
+   ///                   read.
+   /// \returns false when the pipeline is empty, \p SU is null, ChainSeed is
+   ///          not an MFMA/WMMA, or the chain is shorter than \p Number hops.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     // Validate inputs before dereferencing SyncPipe[0].
+     if (!SyncPipe.size())
+       return false;
+
+     if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
+       return false;
+
+     auto DAG = SyncPipe[0].DAG;
+
+     if (Cache->empty()) {
+       auto TempSU = ChainSeed;
+       for (auto Depth = Number; Depth > 0; --Depth) {
+         bool Found = false;
+         for (auto &Succ : TempSU->Succs) {
+           if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
+             TempSU = Succ.getSUnit();
+             Found = true;
+             break;
+           }
+         }
+         // The chain ended before reaching the requested depth.
+         if (!Found)
+           return false;
+       }
+
+       Cache->push_back(TempSU);
+     }
+     // If we failed to find the instruction to be placed into the cache, we
+     // would have already exited.
+     assert(!Cache->empty());
+
+     return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
+   }
+
+   EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
+                         const SIInstrInfo *TII, unsigned SGID,
+                         bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Number(Number),
+         ChainSeed(ChainSeed) {}
+ };
+
+ /// Whether or not the instruction has less than \p Size immediate successors.
+ /// If \p HasIntermediary is true, this tests also whether all successors of
+ /// the SUnit have less than \p Size successors.
+ class LessThanNSuccs final : public InstructionRule {
+ private:
+   // Exclusive upper bound on the number of data successors.
+   unsigned Size = 1;
+   // When set, every successor of the SU must satisfy the same bound.
+   bool HasIntermediary = false;
+
+ public:
+   /// Passes when \p SU has fewer than Size successor edges of kind
+   /// SDep::Data and, if HasIntermediary is set, no successor of \p SU has
+   /// Size or more such edges. Fails outright on an empty \p SyncPipe.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     if (!SyncPipe.size())
+       return false;
+
+     // Number of outgoing data-dependence edges of a node.
+     auto DataSuccCount = [](const SUnit *Node) {
+       return std::count_if(
+           Node->Succs.begin(), Node->Succs.end(),
+           [](const SDep &Edge) { return Edge.getKind() == SDep::Data; });
+     };
+
+     if (DataSuccCount(SU) >= Size)
+       return false;
+
+     if (!HasIntermediary)
+       return true;
+
+     for (const SDep &Edge : SU->Succs)
+       if (DataSuccCount(Edge.getSUnit()) >= Size)
+         return false;
+
+     return true;
+   }
+
+   LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
+                  bool HasIntermediary = false, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Size(Size),
+         HasIntermediary(HasIntermediary) {}
+ };
+
+ /// Whether or not the instruction has greater than or equal to \p Size
+ /// immediate data successors. If \p HasIntermediary is true, the rule also
+ /// passes when any successor of the SUnit has greater than or equal to
+ /// \p Size data successors.
+ class GreaterThanOrEqualToNSuccs final : public InstructionRule {
+ private:
+   // Inclusive lower bound on the number of data successors.
+   unsigned Size = 1;
+   // When set, a qualifying successor of the SU is also accepted.
+   bool HasIntermediary = false;
+
+ public:
+   /// Passes when \p SU has at least Size successor edges of kind SDep::Data,
+   /// or, if HasIntermediary is set, when any successor of \p SU does.
+   /// Fails outright on an empty \p SyncPipe.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     if (!SyncPipe.size())
+       return false;
+
+     // Number of outgoing data-dependence edges of a node.
+     auto DataSuccCount = [](const SUnit *Node) {
+       return std::count_if(
+           Node->Succs.begin(), Node->Succs.end(),
+           [](const SDep &Edge) { return Edge.getKind() == SDep::Data; });
+     };
+
+     if (DataSuccCount(SU) >= Size)
+       return true;
+
+     if (!HasIntermediary)
+       return false;
+
+     for (const SDep &Edge : SU->Succs)
+       if (DataSuccCount(Edge.getSUnit()) >= Size)
+         return true;
+
+     return false;
+   }
+
+   GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
+                              unsigned SGID, bool HasIntermediary = false,
+                              bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Size(Size),
+         HasIntermediary(HasIntermediary) {}
+ };
+
+ // Whether or not the instruction is a relevant V_CVT instruction.
+ class IsCvt final : public InstructionRule {
+ public:
+   /// True exactly when the opcode of \p SU is V_CVT_F16_F32_e32 or
+   /// V_CVT_I32_F32_e32. \p Collection and \p SyncPipe are not consulted.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     switch (SU->getInstr()->getOpcode()) {
+     case AMDGPU::V_CVT_F16_F32_e32:
+     case AMDGPU::V_CVT_I32_F32_e32:
+       return true;
+     default:
+       return false;
+     }
+   }
+
+   IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache) {}
+ };
+
+ // Whether or not the instruction is V_FMA_F32 or V_PK_FMA_F32.
+ class IsFMA final : public InstructionRule {
+ public:
+   /// True exactly when the opcode of \p SU is V_FMA_F32_e64 or V_PK_FMA_F32.
+   /// \p Collection and \p SyncPipe are not consulted.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     const auto Opc = SU->getInstr()->getOpcode();
+     if (Opc == AMDGPU::V_FMA_F32_e64)
+       return true;
+     return Opc == AMDGPU::V_PK_FMA_F32;
+   }
+
+   IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache) {}
+ };
+
+ /// Whether or not the instruction is an immediate RAW successor
+ /// of the SchedGroup \p Distance steps before.
+ class IsSuccOfPrevNthGroup final : public InstructionRule {
+ private:
+   // How many SGIDs back the referenced SchedGroup sits.
+   unsigned Distance = 1;
+
+ public:
+   /// Looks up the SchedGroup in \p SyncPipe whose SGID equals
+   /// SGID - Distance. Fails if the pipeline is empty or no such group
+   /// exists; passes if that group is still empty; otherwise passes only when
+   /// some member of the group has an SDep::Data successor edge to \p SU.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     if (!SyncPipe.size())
+       return false;
+
+     const unsigned WantedID = SGID - Distance;
+     SchedGroup *Earlier = nullptr;
+     // No early exit: the last group with a matching SGID is the one used.
+     for (auto &Group : SyncPipe)
+       if (static_cast<unsigned>(Group.getSGID()) == WantedID)
+         Earlier = &Group;
+
+     if (!Earlier)
+       return false;
+     if (!Earlier->Collection.size())
+       return true;
+
+     for (auto &Member : Earlier->Collection) {
+       for (auto &Edge : Member->Succs) {
+         if (Edge.getSUnit() != SU)
+           continue;
+         if (Edge.getKind() == SDep::Data)
+           return true;
+       }
+     }
+
+     return false;
+   }
+
+   IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
+                        unsigned SGID, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
+ };
+
+ /// Whether or not the instruction is a transitive successor of any
+ /// instruction in the SchedGroup \p Distance steps before.
+ class IsReachableFromPrevNthGroup final : public InstructionRule {
+ private:
+   // How many SGIDs back the referenced SchedGroup sits.
+   unsigned Distance = 1;
+
+ public:
+   /// Looks up the SchedGroup in \p SyncPipe whose SGID equals
+   /// SGID - Distance. Fails if the pipeline is empty or no such group
+   /// exists; passes if that group is still empty; otherwise passes only when
+   /// DAG->IsReachable(SU, Member) holds for some member of the group.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     if (!SyncPipe.size())
+       return false;
+
+     const unsigned WantedID = SGID - Distance;
+     SchedGroup *Earlier = nullptr;
+     // No early exit: the last group with a matching SGID is the one used.
+     for (auto &Group : SyncPipe)
+       if (static_cast<unsigned>(Group.getSGID()) == WantedID)
+         Earlier = &Group;
+
+     if (!Earlier)
+       return false;
+     if (!Earlier->Collection.size())
+       return true;
+
+     auto DAG = SyncPipe[0].DAG;
+     SUnit *Node = const_cast<SUnit *>(SU);
+
+     for (auto &Member : Earlier->Collection) {
+       if (DAG->IsReachable(Node, Member))
+         return true;
+     }
+
+     return false;
+   }
+
+   IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
+                               unsigned SGID, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
+ };
+
+ /// Whether or not the instruction occurs at or after the SU with NodeNum \p Number
+ class OccursAtOrAfterNode final : public InstructionRule {
+ private:
+   // Smallest NodeNum this rule accepts.
+   unsigned Number = 1;
+
+ public:
+   /// True when the NodeNum of \p SU is not below Number. \p Collection and
+   /// \p SyncPipe are not consulted.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     const bool IsEarlier = SU->NodeNum < Number;
+     return !IsEarlier;
+   }
+
+   OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
+                       bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
+ };
+
+ /// Whether or not the SU is exactly the \p Number th MFMA in the chain
+ /// starting with \p ChainSeed
+ class IsExactMFMA final : public InstructionRule {
+ private:
+   // How many MFMA-to-MFMA successor hops to take from ChainSeed.
+   unsigned Number = 1;
+   // First MFMA/WMMA of the chain being walked.
+   SUnit *ChainSeed;
+
+ public:
+   /// Walks Number hops from ChainSeed, each hop taking the first successor
+   /// edge that leads to an MFMA/WMMA, caches the SUnit reached, and passes
+   /// only when \p SU is that exact SUnit. Fails when \p SU is null, when
+   /// ChainSeed is not an MFMA/WMMA, or when the chain ends early.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     const bool Usable = SU && TII->isMFMAorWMMA(*ChainSeed->getInstr());
+     if (!Usable)
+       return false;
+
+     if (Cache->empty()) {
+       SUnit *Cursor = ChainSeed;
+       for (unsigned Hop = 0; Hop != Number; ++Hop) {
+         SUnit *Next = nullptr;
+         for (auto &Edge : Cursor->Succs) {
+           if (!TII->isMFMAorWMMA(*Edge.getSUnit()->getInstr()))
+             continue;
+           Next = Edge.getSUnit();
+           break;
+         }
+         // No MFMA successor: the chain is shorter than requested.
+         if (!Next)
+           return false;
+         Cursor = Next;
+       }
+       Cache->push_back(Cursor);
+     }
+
+     return !Cache->empty() && (*Cache)[0] == SU;
+   }
+
+   IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
+               unsigned SGID, bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache), Number(Number),
+         ChainSeed(ChainSeed) {}
+ };
+
+ // Whether the instruction occurs after the first TRANS instruction. This
+ // implies the instruction cannot be a predecessor of the first TRANS
+ // instruction.
+ class OccursAfterExp final : public InstructionRule {
+ public:
+   /// Finds the first TRANS SUnit in DAG->SUnits order (cached after the
+   /// first lookup) and passes when the NodeNum of \p SU is strictly greater
+   /// than that SUnit's NodeNum.
+   ///
+   /// \param SU         the candidate instruction.
+   /// \param Collection unused by this rule.
+   /// \param SyncPipe   the pipeline of SchedGroups; only SyncPipe[0].DAG is
+   ///                   read.
+   /// \returns false when the pipeline is empty or the DAG has no TRANS
+   ///          instruction.
+   bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+              SmallVectorImpl<SchedGroup> &SyncPipe) override {
+     // Same guard the sibling rules use before touching SyncPipe[0].
+     if (!SyncPipe.size())
+       return false;
+
+     auto DAG = SyncPipe[0].DAG;
+     if (Cache->empty()) {
+       // Named distinctly so it does not shadow the SU parameter.
+       for (auto &Candidate : DAG->SUnits)
+         if (TII->isTRANS(Candidate.getInstr()->getOpcode())) {
+           Cache->push_back(&Candidate);
+           break;
+         }
+     }
+
+     if (Cache->empty())
+       return false;
+
+     return SU->NodeNum > (*Cache)[0]->NodeNum;
+   }
+
+   OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
+                  bool NeedsCache = false)
+       : InstructionRule(TII, SGID, NeedsCache) {}
+ };
+
+public:
+ bool applyIGLPStrategy(
+ DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
+ DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
+ AMDGPU::SchedulingPhase Phase) override;
+
+ bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
+ AMDGPU::SchedulingPhase Phase) override;
+
+ MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
+ : IGLPStrategy(DAG, TII) {
+ IsBottomUp = 0;
+ }
+};
+
+unsigned MFMAExpInterleaveOpt::TransPipeCount = 0; // Reset to 0 and recounted inside analyzeDAG.
+unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0; // Reset to 0 inside analyzeDAG.
+unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
+unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
+unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
+unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
+bool MFMAExpInterleaveOpt::HasCvt = false;
+bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false; // Assigned inside analyzeDAG.
+std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
+
+bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
+ SmallVector<SUnit *, 10> ExpPipeCands;
+ SmallVector<SUnit *, 10> MFMAPipeCands;
+ SmallVector<SUnit *, 10> MFMAPipeSUs;
+ SmallVector<SUnit *, 10> PackSUs;
+ SmallVector<SUnit *, 10> CvtSUs;
+
+ auto isBitPack = [](unsigned Opc) {
+ return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
+ };
+
+ auto isCvt = [](unsigned Opc) {
+ return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
+ };
+
+ for (SUnit &SU : DAG->SUnits) {
+ auto Opc = SU.getInstr()->getOpcode();
+ if (TII->isTRANS(Opc)) {
+ // Avoid counting a potential bonus V_EXP which all the MFMA depend on
+ if (SU.Succs.size() >= 7)
+ continue;
+ for (auto &Succ : SU.Succs) {
+ if (Succ.getSUnit()->Succs.size() >= 7)
+ continue;
+ }
+ ExpPipeCands.push_back(&SU);
+ }
+
+ if (TII->isMFMAorWMMA(*SU.getInstr()))
+ MFMAPipeCands.push_back(&SU);
+
+ if (isBitPack(Opc))
+ PackSUs.push_back(&SU);
+
+ if (isCvt(Opc))
+ CvtSUs.push_back(&SU);
+ }
+
+ if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
+ return false;
+
+ TransPipeCount = 0;
+
+ std::optional<SUnit *> TempMFMA;
+ std::optional<SUnit *> TempExp;
+ // Count the number of EXPs that reach an MFMA
+ for (auto &PredSU : ExpPipeCands) {
+ for (auto &SuccSU : MFMAPipeCands) {
+ if (DAG->IsReachable(SuccSU, PredSU)) {
+ if (!TempExp.has_value()) {
+ TempExp = PredSU;
+ TempMFMA = SuccSU;
+ }
+ MFMAPipeSUs.push_back(SuccSU);
+ ++TransPipeCount;
+ break;
+ }
+ }
+ }
+
+ if (!TempExp.has_value())
+ return false;
+
+ HasChainBetweenCvt =
+ std::find_if((*TempExp)->Succs.begin(), (*TempExp)->Succs.end(),
+ [&isCvt](SDep &Succ) {
+ return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
+ }) == (*TempExp)->Succs.end();
+
+ // Count the number of MFMAs that are reached by an EXP
+ MFMAPipeCount = 0;
+ for (auto &SuccSU : MFMAPipeCands) {
----------------
jrbyrnes wrote:
Yes, MFMAPipeCands are all the MFMAs, and MFMAPipeSUs are the MFMAs which have a V_EXP (transitive) predecessor.
So yes, we can more efficiently determine MFMAPipeCount in this way.
https://github.com/llvm/llvm-project/pull/81342
More information about the llvm-commits
mailing list