[llvm] [AMDGPU] Introduce iglp_opt(2): Generalized exp/mfma interleaving for select kernels (PR #81342)
    Jeffrey Byrnes via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Thu Feb 22 09:46:57 PST 2024
    
    
  
================
@@ -902,6 +904,926 @@ void MFMASmallGemmOpt::applyIGLPStrategy(
         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
   }
+
+  return true;
+}
+
+class MFMAExpInterleaveOpt final : public IGLPStrategy {
+private:
+  // The count of TRANS SUs involved in the interleaved pipeline
+  static unsigned TransPipeCount;
+  // The count of MFMA SUs involved in the interleaved pipeline
+  static unsigned MFMAPipeCount;
+  // The number of transitive MFMA successors for each TRANS SU
+  static unsigned MFMAEnablement;
+  // The number of transitive TRANS predecessors for each MFMA SU
+  static unsigned ExpRequirement;
+  // The count of independent "chains" of MFMA instructions in the pipeline
+  static unsigned MFMAChains;
+  // The length of each independent "chain" of MFMA instructions
+  static unsigned MFMAChainLength;
+  // Whether or not the pipeline has V_CVT instructions
+  static bool HasCvt;
+  // Whether or not there are instructions between the TRANS instruction and
+  // V_CVT
+  static bool HasChainBetweenCvt;
+  // The first occurring DS_READ which feeds an MFMA chain
+  static std::optional<unsigned> FirstPipeDSR;
+  SmallVector<SUnit *, 4> MFMAChainSeeds;
+  // Compute the heuristics for the pipeline, returning whether or not the DAG
+  // is well formatted for the mutation
+  bool analyzeDAG(const SIInstrInfo *TII);
+
+  /// Matches an SU that is a transitive predecessor of at least one MFMA/WMMA
+  /// instruction in the DAG.
+  class IsPipeExp final : public InstructionRule {
+  public:
+    IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto PipeDAG = SyncPipe[0].DAG;
+      auto PipeTII = SyncPipe[0].TII;
+
+      // On first use, fill the cache with every MFMA/WMMA SU, visiting the
+      // SUnits from last to first.
+      if (Cache->empty()) {
+        for (auto It = PipeDAG->SUnits.rbegin(), End = PipeDAG->SUnits.rend();
+             It != End; ++It) {
+          if (PipeTII->isMFMAorWMMA(*It->getInstr()))
+            Cache->push_back(&*It);
+        }
+      }
+
+      // If the DAG holds no MFMA/WMMA the cache stays empty, the loop below
+      // does nothing, and the rule rejects the SU.
+      for (SUnit *MFMASU : *Cache) {
+        if (PipeDAG->IsReachable(MFMASU, const_cast<SUnit *>(SU)))
+          return true;
+      }
+      return false;
+    }
+  };
+
+  /// Whether or not the instruction enables (i.e. is a transitive predecessor
+  /// of) the \p Number th MFMA that appears after the first TRANS instruction
+  /// in the DAG's SUnit order.
+  class EnablesNthMFMA final : public InstructionRule {
+  private:
+    // Index of the target MFMA; counting starts at 1 with the first MFMA seen
+    // after the first TRANS SU.
+    unsigned Number = 1;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto DAG = SyncPipe[0].DAG;
+
+      // Locate the target MFMA once and remember it in the cache.
+      if (Cache->empty()) {
+        auto TII = SyncPipe[0].TII;
+        bool FoundTrans = false;
+        unsigned Counter = 1;
+
+        for (SUnit &Cand : DAG->SUnits) {
+          // Only MFMAs positioned after the first TRANS are counted.
+          if (FoundTrans && TII->isMFMAorWMMA(*Cand.getInstr())) {
+            if (Counter == Number) {
+              Cache->push_back(&Cand);
+              break;
+            }
+            ++Counter;
+          }
+          if (!FoundTrans && TII->isTRANS(Cand.getInstr()->getOpcode()))
+            FoundTrans = true;
+        }
+      }
+      // Either there is no TRANS, or fewer than Number MFMAs follow it.
+      if (Cache->empty())
+        return false;
+
+      return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
+    }
+
+    EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
+                   bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
+  };
+
+  /// Whether or not the instruction enables (i.e. is a transitive predecessor
+  /// of) the MFMA reached by following \p Number MFMA-successor links along
+  /// the chain that starts at \p ChainSeed.
+  class EnablesNthMFMAInChain final : public InstructionRule {
+  private:
+    // Number of MFMA-successor links to follow from ChainSeed.
+    unsigned Number = 1;
+    // First MFMA of the chain being walked.
+    SUnit *ChainSeed = nullptr;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto DAG = SyncPipe[0].DAG;
+      auto TII = SyncPipe[0].TII;
+
+      // The chain can only be walked from a non-null MFMA seed.
+      if (!SU || !ChainSeed || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
+        return false;
+
+      if (Cache->empty()) {
+        // Step to the first MFMA successor, Number times, starting at the seed.
+        SUnit *ChainSU = ChainSeed;
+        for (unsigned Step = 0; Step < Number; ++Step) {
+          bool Found = false;
+          for (auto &Succ : ChainSU->Succs) {
+            if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
+              ChainSU = Succ.getSUnit();
+              Found = true;
+              break;
+            }
+          }
+          // The chain has fewer than Number links; nothing is cached.
+          if (!Found)
+            return false;
+        }
+
+        Cache->push_back(ChainSU);
+      }
+
+      // The cache is guaranteed non-empty here.
+      return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
+    }
+
+    EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
+                          const SIInstrInfo *TII, unsigned SGID,
+                          bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number),
+          ChainSeed(ChainSeed) {}
+  };
+
+  /// Matches SUs that have fewer than \p Size data-dependence successors.
+  /// With \p HasIntermediary set, every successor of the SU must additionally
+  /// have fewer than \p Size data-dependence successors of its own.
+  class LessThanNSuccs final : public InstructionRule {
+  private:
+    unsigned Size = 1;
+    bool HasIntermediary = false;
+
+  public:
+    LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
+                   bool HasIntermediary = false, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Size(Size),
+          HasIntermediary(HasIntermediary) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      if (SyncPipe.empty())
+        return false;
+
+      // Number of outgoing edges of Node whose kind is SDep::Data.
+      auto CountDataSuccs = [](const SUnit *Node) {
+        return std::count_if(
+            Node->Succs.begin(), Node->Succs.end(),
+            [](const SDep &Dep) { return Dep.getKind() == SDep::Data; });
+      };
+
+      if (CountDataSuccs(SU) >= Size)
+        return false;
+      if (!HasIntermediary)
+        return true;
+
+      // Every successor (of any dependence kind) is held to the same limit.
+      for (const SDep &Succ : SU->Succs) {
+        if (CountDataSuccs(Succ.getSUnit()) >= Size)
+          return false;
+      }
+      return true;
+    }
+  };
+
+  /// Whether or not the instruction has greater than or equal to \p Size
+  /// immediate data-dependence successors. If \p HasIntermediary is true, the
+  /// rule also matches when any successor of the SUnit has greater than or
+  /// equal to \p Size data-dependence successors.
+  class GreaterThanNSuccs final : public InstructionRule {
+  private:
+    // Minimum number of data-dependence successors required to match.
+    unsigned Size = 1;
+    // Whether the successors of the SU are examined as well.
+    bool HasIntermediary = false;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      if (!SyncPipe.size())
+        return false;
+
+      // True when Node has at least Size outgoing SDep::Data edges.
+      auto HasEnoughDataSuccs = [this](const SUnit *Node) {
+        auto DataSuccs = std::count_if(
+            Node->Succs.begin(), Node->Succs.end(),
+            [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
+        return DataSuccs >= Size;
+      };
+
+      if (HasEnoughDataSuccs(SU))
+        return true;
+
+      if (HasIntermediary) {
+        // One qualifying successor (reached via any dependence kind) suffices.
+        for (const SDep &Succ : SU->Succs) {
+          if (HasEnoughDataSuccs(Succ.getSUnit()))
+            return true;
+        }
+      }
+
+      return false;
+    }
+    GreaterThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
+                      bool HasIntermediary = false, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Size(Size),
+          HasIntermediary(HasIntermediary) {}
+  };
+
+  // Matches the two V_CVT opcodes recognized here: V_CVT_F16_F32_e32 and
+  // V_CVT_I32_F32_e32.
+  class IsCvt final : public InstructionRule {
+  public:
+    IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      switch (SU->getInstr()->getOpcode()) {
+      case AMDGPU::V_CVT_F16_F32_e32:
+      case AMDGPU::V_CVT_I32_F32_e32:
+        return true;
+      default:
+        return false;
+      }
+    }
+  };
+
+  // Matches the two FMA opcodes recognized here: V_FMA_F32_e64 and
+  // V_PK_FMA_F32.
+  class IsFMA final : public InstructionRule {
+  public:
+    IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      const auto Opc = SU->getInstr()->getOpcode();
+      if (Opc == AMDGPU::V_FMA_F32_e64)
+        return true;
+      return Opc == AMDGPU::V_PK_FMA_F32;
+    }
+  };
+
+  /// Matches an SU that is an immediate data-dependence (RAW) successor of
+  /// some member of the SchedGroup whose ID is this rule's SGID minus
+  /// \p Distance. If that group exists but is still empty, every SU matches;
+  /// if no such group is in the pipeline, none does.
+  class IsSuccOfPrevNthGroup final : public InstructionRule {
+  private:
+    unsigned Distance = 1;
+
+  public:
+    IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
+                         unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      if (SyncPipe.empty())
+        return false;
+
+      // Scan the whole pipeline and keep the last group with the wanted ID.
+      const auto WantedID = SGID - Distance;
+      SchedGroup *PrevGroup = nullptr;
+      for (auto &Group : SyncPipe) {
+        if ((unsigned)Group.getSGID() == WantedID)
+          PrevGroup = &Group;
+      }
+
+      if (!PrevGroup)
+        return false;
+      if (PrevGroup->Collection.empty())
+        return true;
+
+      for (auto &Member : PrevGroup->Collection) {
+        for (auto &Dep : Member->Succs) {
+          if (Dep.getSUnit() == SU && Dep.getKind() == SDep::Data)
+            return true;
+        }
+      }
+      return false;
+    }
+  };
+
+  /// Matches an SU that is a transitive successor of some member of the
+  /// SchedGroup whose ID is this rule's SGID minus \p Distance. If that group
+  /// exists but is still empty, every SU matches; if no such group is in the
+  /// pipeline, none does.
+  class IsReachableFromPrevNthGroup final : public InstructionRule {
+  private:
+    unsigned Distance = 1;
+
+  public:
+    IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
+                                unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      if (SyncPipe.empty())
+        return false;
+
+      // Scan the whole pipeline and keep the last group with the wanted ID.
+      const auto WantedID = SGID - Distance;
+      SchedGroup *PrevGroup = nullptr;
+      for (auto &Group : SyncPipe) {
+        if ((unsigned)Group.getSGID() == WantedID)
+          PrevGroup = &Group;
+      }
+
+      if (!PrevGroup)
+        return false;
+      if (PrevGroup->Collection.empty())
+        return true;
+
+      auto PipeDAG = SyncPipe[0].DAG;
+      return std::any_of(PrevGroup->Collection.begin(),
+                         PrevGroup->Collection.end(),
+                         [&PipeDAG, SU](SUnit *Member) {
+                           return PipeDAG->IsReachable(
+                               const_cast<SUnit *>(SU), Member);
+                         });
+    }
+  };
+
+  /// Matches an SU whose NodeNum is greater than or equal to \p Number, i.e.
+  /// the SU with NodeNum \p Number itself or any SU numbered after it.
+  class OccursAfterDSRead final : public InstructionRule {
+  private:
+    unsigned Number = 1;
+
+  public:
+    OccursAfterDSRead(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
+                      bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
+
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      const auto ThisNode = SU->NodeNum;
+      return !(ThisNode < Number);
+    }
+  };
+
+  /// Whether or not the SU is exactly the MFMA reached by following \p Number
+  /// MFMA-successor links along the chain starting with \p ChainSeed
+  class IsExactMFMA final : public InstructionRule {
+  private:
+    // Number of MFMA-successor links to follow from ChainSeed.
+    unsigned Number = 1;
+    // First MFMA of the chain being walked.
+    SUnit *ChainSeed = nullptr;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto TII = SyncPipe[0].TII;
+
+      // The chain can only be walked from a non-null MFMA seed.
+      if (!SU || !ChainSeed || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
+        return false;
+
+      if (Cache->empty()) {
+        // Step to the first MFMA successor, Number times, starting at the seed.
+        SUnit *ChainSU = ChainSeed;
+        for (unsigned Step = 0; Step < Number; ++Step) {
+          bool Found = false;
+          for (auto &Succ : ChainSU->Succs) {
+            if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
+              ChainSU = Succ.getSUnit();
+              Found = true;
+              break;
+            }
+          }
+          // The chain has fewer than Number links; nothing is cached.
+          if (!Found)
+            return false;
+        }
+        Cache->push_back(ChainSU);
+      }
+
+      // The cache is guaranteed non-empty here.
+      return (*Cache)[0] == SU;
+    }
+
+    IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
+                unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number),
+          ChainSeed(ChainSeed) {}
+  };
+
+  // Whether the instruction occurs after the first TRANS instruction. This
+  // implies the instruction can not be a predecessor of the first TRANS
+  // instruction
+  class OccursAfterExp final : public InstructionRule {
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto DAG = SyncPipe[0].DAG;
+      auto TII = SyncPipe[0].TII;
+
+      // On first use, cache the first TRANS SU in SUnit order.
+      if (Cache->empty()) {
+        for (auto &Cand : DAG->SUnits) {
+          if (TII->isTRANS(Cand.getInstr()->getOpcode())) {
+            Cache->push_back(&Cand);
+            break;
+          }
+        }
+      }
+
+      // The DAG contains no TRANS instruction.
+      if (Cache->empty())
+        return false;
+
+      return SU->NodeNum > (*Cache)[0]->NodeNum;
+    }
+
+    OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
+                   bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+  };
+
+public:
+  bool applyIGLPStrategy(
+      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
+      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
+      IGLPPhase Phase) override;
+
+  bool shouldApplyStrategy(ScheduleDAGInstrs *DAG, IGLPPhase Phase) override;
+
+  MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
+      : IGLPStrategy(DAG, TII) {
+    // Clear the IsBottomUp flag; it is not declared in this class, so it is
+    // presumably inherited from IGLPStrategy.
+    IsBottomUp = false;
+  }
+};
+
+unsigned MFMAExpInterleaveOpt::TransPipeCount = 0; // Recounted by analyzeDAG.
+unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0; // Recounted by analyzeDAG.
+unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
+unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
+unsigned MFMAExpInterleaveOpt::MFMAChains = 0; // Reset by analyzeDAG.
+unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
+bool MFMAExpInterleaveOpt::HasCvt = false; // Recomputed by analyzeDAG.
+bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false; // Set by analyzeDAG.
+std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
+
+bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
+  SmallVector<SUnit *, 10> ExpPipeCands;
+  SmallVector<SUnit *, 10> MFMAPipeCands;
+  SmallVector<SUnit *, 10> MFMAPipeSUs;
+  SmallVector<SUnit *, 10> PackSUs;
+  SmallVector<SUnit *, 10> CvtSUs;
+
+  auto isBitPack = [](unsigned Opc) {
+    return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
+  };
+
+  auto isCvt = [](unsigned Opc) {
+    return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
+  };
+
+  for (SUnit &SU : DAG->SUnits) {
+    auto Opc = SU.getInstr()->getOpcode();
+    if (TII->isTRANS(Opc)) {
+      // Avoid counting a potential bonus V_EXP which all the MFMA depend on
+      if (SU.Succs.size() >= 7)
+        continue;
+      for (auto &Succ : SU.Succs) {
+        if (Succ.getSUnit()->Succs.size() >= 7)
+          continue;
+      }
+      ExpPipeCands.push_back(&SU);
+    }
+
+    if (TII->isMFMAorWMMA(*SU.getInstr()))
+      MFMAPipeCands.push_back(&SU);
+
+    if (isBitPack(Opc))
+      PackSUs.push_back(&SU);
+
+    if (isCvt(Opc))
+      CvtSUs.push_back(&SU);
+  }
+
+  if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
+    return false;
+
+  TransPipeCount = 0;
+
+  std::optional<SUnit *> TempMFMA;
+  std::optional<SUnit *> TempExp;
+  // Count the number of EXPs that reach an MFMA
+  for (auto &PredSU : ExpPipeCands) {
+    for (auto &SuccSU : MFMAPipeCands) {
+      if (DAG->IsReachable(SuccSU, PredSU)) {
+        if (!TempExp.has_value()) {
+          TempExp = PredSU;
+          TempMFMA = SuccSU;
+        }
+        MFMAPipeSUs.push_back(SuccSU);
+        ++TransPipeCount;
+        break;
+      }
+    }
+  }
+
+  if (!TempExp.has_value())
+    return false;
+
+  HasChainBetweenCvt =
+      std::find_if((*TempExp)->Succs.begin(), (*TempExp)->Succs.end(),
+                   [&isCvt](SDep &Succ) {
+                     return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
+                   }) == (*TempExp)->Succs.end();
+
+  // Count the number of MFMAs that are reached by an EXP
+  MFMAPipeCount = 0;
+  for (auto &SuccSU : MFMAPipeCands) {
+    if (MFMAPipeSUs.size() &&
+        std::find_if(MFMAPipeSUs.begin(), MFMAPipeSUs.end(),
+                     [&SuccSU](SUnit *PotentialMatch) {
+                       return PotentialMatch->NodeNum == SuccSU->NodeNum;
+                     }) != MFMAPipeSUs.end()) {
+      ++MFMAPipeCount;
+      continue;
+    }
+    for (auto &PredSU : ExpPipeCands) {
+      if (DAG->IsReachable(SuccSU, PredSU)) {
+        MFMAPipeSUs.push_back(SuccSU);
+        ++MFMAPipeCount;
+        break;
+      }
+    }
+  }
+
+  if (!TempMFMA.has_value() || !TempExp.has_value())
+    return false;
+
+  std::optional<SUnit *> TempCvt;
+  for (auto &SuccSU : CvtSUs) {
+    if (DAG->IsReachable(SuccSU, *TempExp)) {
+      TempCvt = SuccSU;
+      break;
+    }
+  }
+
+  HasCvt = false;
+  if (TempCvt.has_value()) {
+    for (auto &SuccSU : MFMAPipeSUs) {
+      if (DAG->IsReachable(SuccSU, *TempCvt)) {
+        HasCvt = true;
+        break;
+      }
+    }
+  }
+
+  MFMAChains = 0;
----------------
jrbyrnes wrote:
Yes it would be nice if we could reduce the amount of DAG processing -- potentially by producing some fields / a class in the Strategy which the rules could access.
https://github.com/llvm/llvm-project/pull/81342
    
    
More information about the llvm-commits
mailing list