[llvm] c66dee4 - [AMDGPU] Refactor several functions for merging with downstream work. (#110562)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 1 08:28:58 PDT 2024


Author: Gang Chen
Date: 2024-10-01T08:28:55-07:00
New Revision: c66dee4c6bd650ef20105532a311a95abb25ece5

URL: https://github.com/llvm/llvm-project/commit/c66dee4c6bd650ef20105532a311a95abb25ece5
DIFF: https://github.com/llvm/llvm-project/commit/c66dee4c6bd650ef20105532a311a95abb25ece5.diff

LOG: [AMDGPU] Refactor several functions for merging with downstream work. (#110562)

For setScore, the root function is now setScoreByInterval, which takes a
RegInterval input.
For determineWait, the root function is now the determineWait overload
that takes a RegInterval input.

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 80a7529002ac90..e64b35d230d486 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -310,7 +310,14 @@ class WaitcntBrackets {
   bool counterOutOfOrder(InstCounterType T) const;
   void simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const;
   void simplifyWaitcnt(InstCounterType T, unsigned &Count) const;
-  void determineWait(InstCounterType T, int RegNo, AMDGPU::Waitcnt &Wait) const;
+
+  void determineWait(InstCounterType T, RegInterval Interval,
+                     AMDGPU::Waitcnt &Wait) const;
+  void determineWait(InstCounterType T, int RegNo,
+                     AMDGPU::Waitcnt &Wait) const {
+    determineWait(T, {RegNo, RegNo + 1}, Wait);
+  }
+
   void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
   void applyWaitcnt(InstCounterType T, unsigned Count);
   void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
@@ -345,16 +352,22 @@ class WaitcntBrackets {
     LastFlat[DS_CNT] = ScoreUBs[DS_CNT];
   }
 
-  // Return true if there might be pending writes to the specified vgpr by VMEM
+  // Return true if there might be pending writes to the vgpr-interval by VMEM
   // instructions with types different from V.
-  bool hasOtherPendingVmemTypes(int GprNo, VmemType V) const {
-    assert(GprNo < NUM_ALL_VGPRS);
-    return VgprVmemTypes[GprNo] & ~(1 << V);
+  bool hasOtherPendingVmemTypes(RegInterval Interval, VmemType V) const {
+    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      assert(RegNo < NUM_ALL_VGPRS);
+      if (VgprVmemTypes[RegNo] & ~(1 << V))
+        return true;
+    }
+    return false;
   }
 
-  void clearVgprVmemTypes(int GprNo) {
-    assert(GprNo < NUM_ALL_VGPRS);
-    VgprVmemTypes[GprNo] = 0;
+  void clearVgprVmemTypes(RegInterval Interval) {
+    for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      assert(RegNo < NUM_ALL_VGPRS);
+      VgprVmemTypes[RegNo] = 0;
+    }
   }
 
   void setStateOnFunctionEntryOrReturn() {
@@ -396,19 +409,16 @@ class WaitcntBrackets {
   }
 
   void setRegScore(int GprNo, InstCounterType T, unsigned Val) {
-    if (GprNo < NUM_ALL_VGPRS) {
-      VgprUB = std::max(VgprUB, GprNo);
-      VgprScores[T][GprNo] = Val;
-    } else {
-      assert(T == SmemAccessCounter);
-      SgprUB = std::max(SgprUB, GprNo - NUM_ALL_VGPRS);
-      SgprScores[GprNo - NUM_ALL_VGPRS] = Val;
-    }
+    setScoreByInterval({GprNo, GprNo + 1}, T, Val);
   }
 
-  void setExpScore(const MachineInstr *MI, const SIRegisterInfo *TRI,
-                   const MachineRegisterInfo *MRI, const MachineOperand &Op,
-                   unsigned Val);
+  void setScoreByInterval(RegInterval Interval, InstCounterType CntTy,
+                          unsigned Score);
+
+  void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI,
+                         const MachineRegisterInfo *MRI,
+                         const MachineOperand &Op, InstCounterType CntTy,
+                         unsigned Val);
 
   const GCNSubtarget *ST = nullptr;
   InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS;
@@ -772,17 +782,30 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
   return Result;
 }
 
-void WaitcntBrackets::setExpScore(const MachineInstr *MI,
-                                  const SIRegisterInfo *TRI,
-                                  const MachineRegisterInfo *MRI,
-                                  const MachineOperand &Op, unsigned Val) {
-  RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
-  assert(TRI->isVectorRegister(*MRI, Op.getReg()));
+void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
+                                         InstCounterType CntTy,
+                                         unsigned Score) {
   for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-    setRegScore(RegNo, EXP_CNT, Val);
+    if (RegNo < NUM_ALL_VGPRS) {
+      VgprUB = std::max(VgprUB, RegNo);
+      VgprScores[CntTy][RegNo] = Score;
+    } else {
+      assert(CntTy == SmemAccessCounter);
+      SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
+      SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
+    }
   }
 }
 
+void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI,
+                                        const SIRegisterInfo *TRI,
+                                        const MachineRegisterInfo *MRI,
+                                        const MachineOperand &Op,
+                                        InstCounterType CntTy, unsigned Score) {
+  RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
+  setScoreByInterval(Interval, CntTy, Score);
+}
+
 void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
                                     const SIRegisterInfo *TRI,
                                     const MachineRegisterInfo *MRI,
@@ -806,57 +829,61 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
       // All GDS operations must protect their address register (same as
       // export.)
       if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr))
-        setExpScore(&Inst, TRI, MRI, *AddrOp, CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore);
 
       if (Inst.mayStore()) {
         if (const auto *Data0 =
                 TII->getNamedOperand(Inst, AMDGPU::OpName::data0))
-          setExpScore(&Inst, TRI, MRI, *Data0, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore);
         if (const auto *Data1 =
                 TII->getNamedOperand(Inst, AMDGPU::OpName::data1))
-          setExpScore(&Inst, TRI, MRI, *Data1, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) &&
                  Inst.getOpcode() != AMDGPU::DS_APPEND &&
                  Inst.getOpcode() != AMDGPU::DS_CONSUME &&
                  Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) {
         for (const MachineOperand &Op : Inst.all_uses()) {
           if (TRI->isVectorRegister(*MRI, Op.getReg()))
-            setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+            setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
         }
       }
     } else if (TII->isFLAT(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isMIMG(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+                          CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isMTBUF(Inst)) {
       if (Inst.mayStore())
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+                          CurrScore);
     } else if (TII->isMUBUF(Inst)) {
       if (Inst.mayStore()) {
-        setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+                          CurrScore);
       } else if (SIInstrInfo::isAtomicRet(Inst)) {
-        setExpScore(&Inst, TRI, MRI,
-                    *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
-                    CurrScore);
+        setScoreByOperand(&Inst, TRI, MRI,
+                          *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+                          EXP_CNT, CurrScore);
       }
     } else if (TII->isLDSDIR(Inst)) {
       // LDSDIR instructions attach the score to the destination.
-      setExpScore(&Inst, TRI, MRI,
-                  *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), CurrScore);
+      setScoreByOperand(&Inst, TRI, MRI,
+                        *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst),
+                        EXP_CNT, CurrScore);
     } else {
       if (TII->isEXP(Inst)) {
         // For export the destination registers are really temps that
@@ -865,15 +892,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
         // score.
         for (MachineOperand &DefMO : Inst.all_defs()) {
           if (TRI->isVGPR(*MRI, DefMO.getReg())) {
-            setRegScore(
-                TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)),
-                EXP_CNT, CurrScore);
+            setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore);
           }
         }
       }
       for (const MachineOperand &Op : Inst.all_uses()) {
         if (TRI->isVectorRegister(*MRI, Op.getReg()))
-          setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+          setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
       }
     }
   } else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
@@ -901,9 +926,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
             VgprVmemTypes[RegNo] |= 1 << V;
         }
       }
-      for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-        setRegScore(RegNo, T, CurrScore);
-      }
+      setScoreByInterval(Interval, T, CurrScore);
     }
     if (Inst.mayStore() &&
         (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
@@ -1034,31 +1057,34 @@ void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
     Count = ~0u;
 }
 
-void WaitcntBrackets::determineWait(InstCounterType T, int RegNo,
+void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
                                     AMDGPU::Waitcnt &Wait) const {
-  unsigned ScoreToWait = getRegScore(RegNo, T);
-
-  // If the score of src_operand falls within the bracket, we need an
-  // s_waitcnt instruction.
   const unsigned LB = getScoreLB(T);
   const unsigned UB = getScoreUB(T);
-  if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
-    if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
-        !ST->hasFlatLgkmVMemCountInOrder()) {
-      // If there is a pending FLAT operation, and this is a VMem or LGKM
-      // waitcnt and the target can report early completion, then we need
-      // to force a waitcnt 0.
-      addWait(Wait, T, 0);
-    } else if (counterOutOfOrder(T)) {
-      // Counter can get decremented out-of-order when there
-      // are multiple types event in the bracket. Also emit an s_wait counter
-      // with a conservative value of 0 for the counter.
-      addWait(Wait, T, 0);
-    } else {
-      // If a counter has been maxed out avoid overflow by waiting for
-      // MAX(CounterType) - 1 instead.
-      unsigned NeededWait = std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
-      addWait(Wait, T, NeededWait);
+  for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+    unsigned ScoreToWait = getRegScore(RegNo, T);
+
+    // If the score of src_operand falls within the bracket, we need an
+    // s_waitcnt instruction.
+    if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
+      if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
+          !ST->hasFlatLgkmVMemCountInOrder()) {
+        // If there is a pending FLAT operation, and this is a VMem or LGKM
+        // waitcnt and the target can report early completion, then we need
+        // to force a waitcnt 0.
+        addWait(Wait, T, 0);
+      } else if (counterOutOfOrder(T)) {
+        // Counter can get decremented out-of-order when there
+        // are multiple event types in the bracket. Also emit an s_wait counter
+        // with a conservative value of 0 for the counter.
+        addWait(Wait, T, 0);
+      } else {
+        // If a counter has been maxed out avoid overflow by waiting for
+        // MAX(CounterType) - 1 instead.
+        unsigned NeededWait =
+            std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
+        addWait(Wait, T, NeededWait);
+      }
     }
   }
 }
@@ -1670,18 +1696,16 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         RegInterval CallAddrOpInterval =
             ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp);
 
-        for (int RegNo = CallAddrOpInterval.first;
-             RegNo < CallAddrOpInterval.second; ++RegNo)
-          ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+        ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval,
+                                    Wait);
 
         if (const auto *RtnAddrOp =
                 TII->getNamedOperand(MI, AMDGPU::OpName::dst)) {
           RegInterval RtnAddrOpInterval =
               ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp);
 
-          for (int RegNo = RtnAddrOpInterval.first;
-               RegNo < RtnAddrOpInterval.second; ++RegNo)
-            ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+          ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval,
+                                      Wait);
         }
       }
     } else {
@@ -1750,36 +1774,34 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op);
 
         const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
-        for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-          if (IsVGPR) {
-            // Implicit VGPR defs and uses are never a part of the memory
-            // instructions description and usually present to account for
-            // super-register liveness.
-            // TODO: Most of the other instructions also have implicit uses
-            // for the liveness accounting only.
-            if (Op.isImplicit() && MI.mayLoadOrStore())
-              continue;
-
-            // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
-            // previous write and this write are the same type of VMEM
-            // instruction, in which case they are (in some architectures)
-            // guaranteed to write their results in order anyway.
-            if (Op.isUse() || !updateVMCntOnly(MI) ||
-                ScoreBrackets.hasOtherPendingVmemTypes(RegNo,
-                                                       getVmemType(MI)) ||
-                !ST->hasVmemWriteVgprInOrder()) {
-              ScoreBrackets.determineWait(LOAD_CNT, RegNo, Wait);
-              ScoreBrackets.determineWait(SAMPLE_CNT, RegNo, Wait);
-              ScoreBrackets.determineWait(BVH_CNT, RegNo, Wait);
-              ScoreBrackets.clearVgprVmemTypes(RegNo);
-            }
-            if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
-              ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
-            }
-            ScoreBrackets.determineWait(DS_CNT, RegNo, Wait);
-          } else {
-            ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+        if (IsVGPR) {
+          // Implicit VGPR defs and uses are never a part of the memory
+          // instructions description and usually present to account for
+          // super-register liveness.
+          // TODO: Most of the other instructions also have implicit uses
+          // for the liveness accounting only.
+          if (Op.isImplicit() && MI.mayLoadOrStore())
+            continue;
+
+          // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
+          // previous write and this write are the same type of VMEM
+          // instruction, in which case they are (in some architectures)
+          // guaranteed to write their results in order anyway.
+          if (Op.isUse() || !updateVMCntOnly(MI) ||
+              ScoreBrackets.hasOtherPendingVmemTypes(Interval,
+                                                     getVmemType(MI)) ||
+              !ST->hasVmemWriteVgprInOrder()) {
+            ScoreBrackets.determineWait(LOAD_CNT, Interval, Wait);
+            ScoreBrackets.determineWait(SAMPLE_CNT, Interval, Wait);
+            ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
+            ScoreBrackets.clearVgprVmemTypes(Interval);
+          }
+          if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
+            ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
           }
+          ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
+        } else {
+          ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
         }
       }
     }


        


More information about the llvm-commits mailing list