[llvm] [AMDGPU] Introduce iglp_opt(2): Generalized exp/mfma interleaving for select kernels (PR #81342)

Austin Kerbow via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 23:53:39 PST 2024


================
@@ -902,6 +904,926 @@ void MFMASmallGemmOpt::applyIGLPStrategy(
         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
   }
+
+  return true;
+}
+
+class MFMAExpInterleaveOpt final : public IGLPStrategy {
+private:
+  // The count of TRANS SUs involved in the interleaved pipeline
+  static unsigned TransPipeCount;
+  // The count of MFMA SUs involved in the interleaved pipeline
+  static unsigned MFMAPipeCount;
+  // The number of transitive MFMA successors for each TRANS SU
+  static unsigned MFMAEnablement;
+  // The number of transitive TRANS predecessors for each MFMA SU
+  static unsigned ExpRequirement;
+  // The count of independent "chains" of MFMA instructions in the pipeline
+  static unsigned MFMAChains;
+  // The length of each independent "chain" of MFMA instructions
+  static unsigned MFMAChainLength;
+  // Whether or not the pipeline has V_CVT instructions
+  static bool HasCvt;
+  // Whether or not there are instructions between the TRANS instruction and
+  // V_CVT
+  static bool HasChainBetweenCvt;
+  // The first occurring DS_READ which feeds an MFMA chain
+  static std::optional<unsigned> FirstPipeDSR;
+  SmallVector<SUnit *, 4> MFMAChainSeeds;
+  // Compute the heuristics for the pipeline, returning whether or not the DAG
+  // is well formatted for the mutation
+  bool analyzeDAG(const SIInstrInfo *TII);
+
+  /// Whether or not the instruction is a transitive predecessor of an MFMA
+  /// instruction
+  class IsPipeExp final : public InstructionRule {
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+
+      auto DAG = SyncPipe[0].DAG;
+      auto TII = SyncPipe[0].TII;
+
+      if (Cache->empty()) {
+        auto I = DAG->SUnits.rbegin();
+        auto E = DAG->SUnits.rend();
+        for (; I != E; I++) {
+          if (TII->isMFMAorWMMA(*I->getInstr()))
+            Cache->push_back(&*I);
+        }
+        if (Cache->empty())
+          return false;
+      }
+
+      auto Reaches = (std::any_of(
+          Cache->begin(), Cache->end(), [&SU, &DAG](SUnit *TargetSU) {
+            return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
+          }));
+
+      return Reaches;
+    }
+    IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+  };
+
+  /// Whether or not the instruction enables the exact MFMA that is the \p
+  /// Number th MFMA following the first TRANS instruction in SUnit order
+  class EnablesNthMFMA final : public InstructionRule {
+  private:
+    unsigned Number = 1;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      bool FoundTrans = false;
+      unsigned Counter = 1;
+      auto DAG = SyncPipe[0].DAG;
+
+      if (Cache->empty()) {
+        auto TII = SyncPipe[0].TII;
+        SmallVector<SUnit *, 8> Worklist;
+
+        auto I = DAG->SUnits.begin();
+        auto E = DAG->SUnits.end();
+        for (; I != E; I++) {
+          if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
+            if (Counter == Number) {
+              Cache->push_back(&*I);
+              break;
+            }
+            ++Counter;
+          }
+          if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
+            FoundTrans = true;
+        }
+      }
+      if (Cache->empty())
+        return false;
+
+      return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
+    }
+
+    EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
+                   bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
+  };
+
+  /// Whether or not the instruction is a transitive predecessor of the MFMA
+  /// that lies \p Number MFMA-successor steps down the chain starting with
+  /// \p ChainSeed
+  class EnablesNthMFMAInChain final : public InstructionRule {
+  private:
+    unsigned Number = 1;
+    SUnit *ChainSeed;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto DAG = SyncPipe[0].DAG;
+      auto TII = SyncPipe[0].TII;
+
+      if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
+        return false;
+
+      if (Cache->empty()) {
+        auto TempSU = ChainSeed;
+        auto Depth = Number;
+        while (Depth > 0) {
+          --Depth;
+          bool Found = false;
+          for (auto &Succ : TempSU->Succs) {
+            if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
+              TempSU = Succ.getSUnit();
+              Found = true;
+              break;
+            }
+          }
+          if (!Found)
+            return false;
+        }
+
+        Cache->push_back(TempSU);
+      }
+      if (Cache->empty())
+        return false;
+
+      return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
+    }
+
+    EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
+                          const SIInstrInfo *TII, unsigned SGID,
+                          bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number),
+          ChainSeed(ChainSeed) {}
+  };
+
+  /// Whether or not the instruction has fewer than \p Size immediate data
+  /// successors. If \p HasIntermediary is true, this also tests whether all
+  /// successors of the SUnit have fewer than \p Size data successors.
+  class LessThanNSuccs final : public InstructionRule {
+  private:
+    unsigned Size = 1;
+    bool HasIntermediary = false;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      if (!SyncPipe.size())
+        return false;
+
+      auto SuccSize = std::count_if(
+          SU->Succs.begin(), SU->Succs.end(),
+          [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
+      if (SuccSize >= Size)
+        return false;
+
+      if (HasIntermediary) {
+        for (auto Succ : SU->Succs) {
+          auto SuccSize = std::count_if(
+              Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
+              [](const SDep &SuccSucc) {
+                return SuccSucc.getKind() == SDep::Data;
+              });
+          if (SuccSize >= Size)
+            return false;
+        }
+      }
+
+      return true;
+    }
+    LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
+                   bool HasIntermediary = false, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Size(Size),
+          HasIntermediary(HasIntermediary) {}
+  };
+
+  /// Whether or not the instruction has greater than or equal to \p Size
+  /// immediate data successors. If \p HasIntermediary is true, this also
+  /// accepts the SUnit when any of its successors has greater than or equal to
+  /// \p Size data successors.
+  class GreaterThanNSuccs final : public InstructionRule {
+  private:
+    unsigned Size = 1;
+    bool HasIntermediary = false;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      if (!SyncPipe.size())
+        return false;
+
+      auto SuccSize = std::count_if(
+          SU->Succs.begin(), SU->Succs.end(),
+          [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
+      if (SuccSize >= Size)
+        return true;
+
+      if (HasIntermediary) {
+        for (auto Succ : SU->Succs) {
+          auto SuccSize = std::count_if(
+              Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
+              [](const SDep &SuccSucc) {
+                return SuccSucc.getKind() == SDep::Data;
+              });
+          if (SuccSize >= Size)
+            return true;
+        }
+      }
+
+      return false;
+    }
+    GreaterThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
+                      bool HasIntermediary = false, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Size(Size),
+          HasIntermediary(HasIntermediary) {}
+  };
+
+  // Whether or not the instruction is a V_CVT instruction.
+  class IsCvt final : public InstructionRule {
+  private:
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto Opc = SU->getInstr()->getOpcode();
+      return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
+             Opc == AMDGPU::V_CVT_I32_F32_e32;
+    }
+    IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+  };
+
+  // Whether or not the instruction is a V_FMA_F32 or V_PK_FMA_F32 instruction.
+  class IsFMA final : public InstructionRule {
+  private:
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
+             SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
+    }
+    IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+  };
+
+  /// Whether or not the instruction is an immediate RAW successor
+  /// of the SchedGroup \p Distance steps before.
+  class IsSuccOfPrevNthGroup final : public InstructionRule {
+  private:
+    unsigned Distance = 1;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      SchedGroup *OtherGroup = nullptr;
+      if (!SyncPipe.size())
+        return false;
+
+      for (auto &PipeSG : SyncPipe) {
+        if ((unsigned)PipeSG.getSGID() == SGID - Distance)
+          OtherGroup = &PipeSG;
+      }
+
+      if (!OtherGroup)
+        return false;
+      if (!OtherGroup->Collection.size())
+        return true;
+
+      for (auto &OtherEle : OtherGroup->Collection) {
+        for (auto &Succ : OtherEle->Succs) {
+          if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
+            return true;
+        }
+      }
+
+      return false;
+    }
+    IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
+                         unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
+  };
+
+  /// Whether or not the instruction is a transitive successor of any
+  /// instruction in the SchedGroup \p Distance steps before.
+  class IsReachableFromPrevNthGroup final : public InstructionRule {
+  private:
+    unsigned Distance = 1;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      SchedGroup *OtherGroup = nullptr;
+      if (!SyncPipe.size())
+        return false;
+
+      for (auto &PipeSG : SyncPipe) {
+        if ((unsigned)PipeSG.getSGID() == SGID - Distance)
+          OtherGroup = &PipeSG;
+      }
+
+      if (!OtherGroup)
+        return false;
+      if (!OtherGroup->Collection.size())
+        return true;
+
+      auto DAG = SyncPipe[0].DAG;
+
+      for (auto &OtherEle : OtherGroup->Collection)
+        if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
+          return true;
+
+      return false;
+    }
+    IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
+                                unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
+  };
+
+  /// Whether or not the SU occurs at or after the SU with NodeNum \p Number
+  class OccursAfterDSRead final : public InstructionRule {
+  private:
+    unsigned Number = 1;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+
+      return SU->NodeNum >= Number;
+    }
+    OccursAfterDSRead(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
+                      bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
+  };
+
+  /// Whether or not the SU is exactly the \p Number th MFMA in the chain
+  /// starting with \p ChainSeed
+  class IsExactMFMA final : public InstructionRule {
+  private:
+    unsigned Number = 1;
+    SUnit *ChainSeed;
+
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+      auto TII = SyncPipe[0].TII;
+      if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
+        return false;
+
+      if (Cache->empty()) {
+        auto TempSU = ChainSeed;
+        auto Depth = Number;
+        while (Depth > 0) {
+          --Depth;
+          bool Found = false;
+          for (auto &Succ : TempSU->Succs) {
+            if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
+              TempSU = Succ.getSUnit();
+              Found = true;
+              break;
+            }
+          }
+          if (!Found) {
+            return false;
+          }
+        }
+        Cache->push_back(TempSU);
+      }
+
+      if (Cache->empty())
+        return false;
+
+      return (*Cache)[0] == SU;
+    }
+
+    IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
+                unsigned SGID, bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache), Number(Number),
+          ChainSeed(ChainSeed) {}
+  };
+
+  // Whether the instruction occurs after the first TRANS instruction. This
+  // implies the instruction cannot be a predecessor of the first TRANS
+  // instruction
+  class OccursAfterExp final : public InstructionRule {
+  public:
+    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
+               SmallVectorImpl<SchedGroup> &SyncPipe) override {
+
+      SmallVector<SUnit *, 12> Worklist;
+      auto DAG = SyncPipe[0].DAG;
+      auto TII = SyncPipe[0].TII;
+      if (Cache->empty()) {
+        for (auto &SU : DAG->SUnits)
+          if (TII->isTRANS(SU.getInstr()->getOpcode())) {
+            Cache->push_back(&SU);
+            break;
+          }
+      }
+
+      if (Cache->empty())
+        return false;
+
+      return SU->NodeNum > (*Cache)[0]->NodeNum;
+    }
+
+    OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
+                   bool NeedsCache = false)
+        : InstructionRule(TII, SGID, NeedsCache) {}
+  };
+
+public:
+  bool applyIGLPStrategy(
+      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
+      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
+      IGLPPhase Phase) override;
+
+  bool shouldApplyStrategy(ScheduleDAGInstrs *DAG, IGLPPhase Phase) override;
+
+  MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
+      : IGLPStrategy(DAG, TII) {
+    IsBottomUp = 0;
+  }
+};
+
+unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
+unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
+unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
+unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
+unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
+unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
+bool MFMAExpInterleaveOpt::HasCvt = false;
+bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
+std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
+
+bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
+  SmallVector<SUnit *, 10> ExpPipeCands;
+  SmallVector<SUnit *, 10> MFMAPipeCands;
+  SmallVector<SUnit *, 10> MFMAPipeSUs;
+  SmallVector<SUnit *, 10> PackSUs;
+  SmallVector<SUnit *, 10> CvtSUs;
+
+  auto isBitPack = [](unsigned Opc) {
+    return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
+  };
+
+  auto isCvt = [](unsigned Opc) {
+    return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
+  };
+
+  for (SUnit &SU : DAG->SUnits) {
+    auto Opc = SU.getInstr()->getOpcode();
+    if (TII->isTRANS(Opc)) {
+      // Avoid counting a potential bonus V_EXP which all the MFMA depend on
+      if (SU.Succs.size() >= 7)
+        continue;
+      for (auto &Succ : SU.Succs) {
+        if (Succ.getSUnit()->Succs.size() >= 7)
+          continue;
+      }
+      ExpPipeCands.push_back(&SU);
+    }
+
+    if (TII->isMFMAorWMMA(*SU.getInstr()))
+      MFMAPipeCands.push_back(&SU);
+
+    if (isBitPack(Opc))
+      PackSUs.push_back(&SU);
+
+    if (isCvt(Opc))
+      CvtSUs.push_back(&SU);
+  }
+
+  if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
+    return false;
+
+  TransPipeCount = 0;
+
+  std::optional<SUnit *> TempMFMA;
+  std::optional<SUnit *> TempExp;
+  // Count the number of EXPs that reach an MFMA
+  for (auto &PredSU : ExpPipeCands) {
+    for (auto &SuccSU : MFMAPipeCands) {
+      if (DAG->IsReachable(SuccSU, PredSU)) {
+        if (!TempExp.has_value()) {
+          TempExp = PredSU;
+          TempMFMA = SuccSU;
+        }
+        MFMAPipeSUs.push_back(SuccSU);
+        ++TransPipeCount;
+        break;
+      }
+    }
+  }
+
+  if (!TempExp.has_value())
+    return false;
+
+  HasChainBetweenCvt =
+      std::find_if((*TempExp)->Succs.begin(), (*TempExp)->Succs.end(),
+                   [&isCvt](SDep &Succ) {
+                     return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
+                   }) == (*TempExp)->Succs.end();
+
+  // Count the number of MFMAs that are reached by an EXP
+  MFMAPipeCount = 0;
+  for (auto &SuccSU : MFMAPipeCands) {
+    if (MFMAPipeSUs.size() &&
+        std::find_if(MFMAPipeSUs.begin(), MFMAPipeSUs.end(),
+                     [&SuccSU](SUnit *PotentialMatch) {
+                       return PotentialMatch->NodeNum == SuccSU->NodeNum;
+                     }) != MFMAPipeSUs.end()) {
+      ++MFMAPipeCount;
+      continue;
+    }
+    for (auto &PredSU : ExpPipeCands) {
+      if (DAG->IsReachable(SuccSU, PredSU)) {
+        MFMAPipeSUs.push_back(SuccSU);
+        ++MFMAPipeCount;
+        break;
+      }
+    }
+  }
+
+  if (!TempMFMA.has_value() || !TempExp.has_value())
+    return false;
+
+  std::optional<SUnit *> TempCvt;
+  for (auto &SuccSU : CvtSUs) {
+    if (DAG->IsReachable(SuccSU, *TempExp)) {
+      TempCvt = SuccSU;
+      break;
+    }
+  }
+
+  HasCvt = false;
+  if (TempCvt.has_value()) {
+    for (auto &SuccSU : MFMAPipeSUs) {
+      if (DAG->IsReachable(SuccSU, *TempCvt)) {
+        HasCvt = true;
+        break;
+      }
+    }
+  }
+
+  MFMAChains = 0;
----------------
kerbowa wrote:

It would be nice to eventually refactor things so that rules can access the strategy's shared data directly, e.g. so that the MFMA chains could be fed straight into the rules.

https://github.com/llvm/llvm-project/pull/81342


More information about the llvm-commits mailing list