[llvm] 021def6 - [AMDGPU] Use alias info to relax waitcounts for LDS DMA (#74537)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 17 23:44:20 PST 2024


Author: Stanislav Mekhanoshin
Date: 2024-01-17T23:44:15-08:00
New Revision: 021def6c2278fd932d18b4d891c2e75c1d8e6f1d

URL: https://github.com/llvm/llvm-project/commit/021def6c2278fd932d18b4d891c2e75c1d8e6f1d
DIFF: https://github.com/llvm/llvm-project/commit/021def6c2278fd932d18b4d891c2e75c1d8e6f1d.diff

LOG: [AMDGPU] Use alias info to relax waitcounts for LDS DMA (#74537)

LDS DMA loads increase VMCNT, and a load from the LDS location they store to
must wait on this counter so that it only reads memory after it is written.
The wait count insertion pass does not track memory dependencies; it tracks
register dependencies. To model the LDS dependency a pseudo register is used
in the scoreboard, as if the LDS DMA writes it and the LDS load reads it.

This patch adds 8 more pseudo registers to use for independent LDS
locations if we can prove they are disjoint using alias analysis.

Fixes: SWDEV-433427

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
    llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 2ae028477ac7e3..62c977fc96a89d 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/InitializePasses.h"
@@ -121,8 +122,13 @@ enum RegisterMapping {
   SQ_MAX_PGM_VGPRS = 512, // Maximum programmable VGPRs across all targets.
   AGPR_OFFSET = 256,      // Maximum programmable ArchVGPRs across all targets.
   SQ_MAX_PGM_SGPRS = 256, // Maximum programmable SGPRs across all targets.
-  NUM_EXTRA_VGPRS = 1,    // A reserved slot for DS.
-  EXTRA_VGPR_LDS = 0,     // An artificial register to track LDS writes.
+  NUM_EXTRA_VGPRS = 9,    // Reserved slots for DS.
+  // Artificial register slots to track LDS writes into specific LDS locations
+  // if a location is known. When slots are exhausted or location is
+  // unknown use the first slot. The first slot is also always updated in
+  // addition to known location's slot to properly generate waits if dependent
+  // instruction's location is unknown.
+  EXTRA_VGPR_LDS = 0,
   NUM_ALL_VGPRS = SQ_MAX_PGM_VGPRS + NUM_EXTRA_VGPRS, // Where SGPR starts.
 };
 
@@ -297,6 +303,10 @@ class WaitcntBrackets {
     PendingEvents |= WaitEventMaskForInst[VS_CNT];
   }
 
+  ArrayRef<const MachineInstr *> getLDSDMAStores() const {
+    return LDSDMAStores;
+  }
+
   void print(raw_ostream &);
   void dump() { print(dbgs()); }
 
@@ -359,6 +369,9 @@ class WaitcntBrackets {
   // Bitmask of the VmemTypes of VMEM instructions that might have a pending
   // write to each vgpr.
   unsigned char VgprVmemTypes[NUM_ALL_VGPRS] = {0};
+  // Store representative LDS DMA operations. The only useful info here is
+  // alias info. One store is kept per unique AAInfo.
+  SmallVector<const MachineInstr *, NUM_EXTRA_VGPRS - 1> LDSDMAStores;
 };
 
 class SIInsertWaitcnts : public MachineFunctionPass {
@@ -373,6 +386,7 @@ class SIInsertWaitcnts : public MachineFunctionPass {
   DenseMap<MachineBasicBlock *, bool> PreheadersToFlush;
   MachineLoopInfo *MLI;
   MachinePostDominatorTree *PDT;
+  AliasAnalysis *AA = nullptr;
 
   struct BlockInfo {
     std::unique_ptr<WaitcntBrackets> Incoming;
@@ -415,6 +429,8 @@ class SIInsertWaitcnts : public MachineFunctionPass {
     AU.setPreservesCFG();
     AU.addRequired<MachineLoopInfo>();
     AU.addRequired<MachinePostDominatorTree>();
+    AU.addUsedIfAvailable<AAResultsWrapperPass>();
+    AU.addPreserved<AAResultsWrapperPass>();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
 
@@ -707,7 +723,40 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
         (TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
       // MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS
       // written can be accessed. A load from LDS to VMEM does not need a wait.
-      setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
+      unsigned Slot = 0;
+      for (const auto *MemOp : Inst.memoperands()) {
+        if (!MemOp->isStore() ||
+            MemOp->getAddrSpace() != AMDGPUAS::LOCAL_ADDRESS)
+          continue;
+        // Comparing just AA info does not guarantee memoperands are equal
+        // in general, but this is so for LDS DMA in practice.
+        auto AAI = MemOp->getAAInfo();
+        // Alias scope information gives a way to definitively identify the
+        // original memory object; in practice it is produced by the module
+        // LDS lowering pass. If there is no scope available we will not be
+        // able to disambiguate LDS aliasing, because after module lowering
+        // all LDS is squashed into a single big object. Do not spend one of
+        // the limited LDSDMAStores slots on something we will not be able to
+        // use anyway.
+        if (!AAI || !AAI.Scope)
+          break;
+        for (unsigned I = 0, E = LDSDMAStores.size(); I != E && !Slot; ++I) {
+          for (const auto *MemOp : LDSDMAStores[I]->memoperands()) {
+            if (MemOp->isStore() && AAI == MemOp->getAAInfo()) {
+              Slot = I + 1;
+              break;
+            }
+          }
+        }
+        if (Slot || LDSDMAStores.size() == NUM_EXTRA_VGPRS - 1)
+          break;
+        LDSDMAStores.push_back(&Inst);
+        Slot = LDSDMAStores.size();
+        break;
+      }
+      setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS + Slot, T, CurrScore);
+      if (Slot)
+        setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
     }
   }
 }
@@ -1183,9 +1232,27 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         // No need to wait before load from VMEM to LDS.
         if (TII->mayWriteLDSThroughDMA(MI))
           continue;
-        unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
+
         // VM_CNT is only relevant to vgpr or LDS.
-        ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
+        unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
+        bool FoundAliasingStore = false;
+        // Only objects with alias scope info were added to LDSDMAStores array.
+        // In the absence of the scope info we will not be able to disambiguate
+        // aliasing here. There is no need to try searching for a corresponding
+        // store slot. This is conservatively correct because in that case we
+        // will produce a wait using the first (general) LDS DMA wait slot which
+        // will wait on all of them anyway.
+        if (Ptr && Memop->getAAInfo() && Memop->getAAInfo().Scope) {
+          const auto &LDSDMAStores = ScoreBrackets.getLDSDMAStores();
+          for (unsigned I = 0, E = LDSDMAStores.size(); I != E; ++I) {
+            if (MI.mayAlias(AA, *LDSDMAStores[I], true)) {
+              FoundAliasingStore = true;
+              ScoreBrackets.determineWait(VM_CNT, RegNo + I + 1, Wait);
+            }
+          }
+        }
+        if (!FoundAliasingStore)
+          ScoreBrackets.determineWait(VM_CNT, RegNo, Wait);
         if (Memop->isStore()) {
           ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
         }
@@ -1834,6 +1901,8 @@ bool SIInsertWaitcnts::runOnMachineFunction(MachineFunction &MF) {
   const SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
   MLI = &getAnalysis<MachineLoopInfo>();
   PDT = &getAnalysis<MachinePostDominatorTree>();
+  if (auto AAR = getAnalysisIfAvailable<AAResultsWrapperPass>())
+    AA = &AAR->getAAResults();
 
   ForceEmitZeroWaitcnts = ForceEmitZeroFlag;
   for (auto T : inst_counter_types())

diff  --git a/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll b/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
index 1d6968a86a2e25..5cb3ca0b80b668 100644
--- a/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
+++ b/llvm/test/CodeGen/AMDGPU/lds-dma-waits.ll
@@ -3,21 +3,23 @@
 
 @lds.0 = internal addrspace(3) global [64 x float] poison, align 16
 @lds.1 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.2 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.3 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.4 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.5 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.6 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.7 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.8 = internal addrspace(3) global [64 x float] poison, align 16
+@lds.9 = internal addrspace(3) global [64 x float] poison, align 16
 
 declare void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) nocapture, i32 %size, i32 %voffset, i32 %soffset, i32 %offset, i32 %aux)
 declare void @llvm.amdgcn.global.load.lds(ptr addrspace(1) nocapture %gptr, ptr addrspace(3) nocapture %lptr, i32 %size, i32 %offset, i32 %aux)
 
-; FIXME: vmcnt(0) is too strong, it shall use vmcnt(2) before the first
-;        ds_read_b32 and vmcnt(0) before the second.
-
 ; GCN-LABEL: {{^}}buffer_load_lds_dword_2_arrays:
 ; GCN-COUNT-4: buffer_load_dword
-; GCN: s_waitcnt vmcnt(0)
+; GCN: s_waitcnt vmcnt(2)
 ; GCN: ds_read_b32
-
-; FIXME:
-; GCN-NOT: s_waitcnt
-
+; GCN: s_waitcnt vmcnt(0)
 ; GCN: ds_read_b32
 define amdgpu_kernel void @buffer_load_lds_dword_2_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
 main_body:
@@ -43,15 +45,9 @@ main_body:
 ; GCN-COUNT-4: global_load_dword
 ; GFX9: s_waitcnt vmcnt(0)
 ; GFX9-COUNT-2: ds_read_b32
-
-; FIXME: can be vmcnt(2)
-
-; GFX10: s_waitcnt vmcnt(0)
+; GFX10: s_waitcnt vmcnt(2)
 ; GFX10: ds_read_b32
-
-; FIXME:
-; GFX10-NOT: s_waitcnt
-
+; GFX10: s_waitcnt vmcnt(0)
 ; GFX10: ds_read_b32
 define amdgpu_kernel void @global_load_lds_dword_2_arrays(ptr addrspace(1) nocapture %gptr, i32 %i1, i32 %i2, ptr addrspace(1) %out) {
 main_body:
@@ -70,4 +66,89 @@ main_body:
   ret void
 }
 
+; There are 8 pseudo registers defined to track LDS DMA dependencies.
+; When exhausted we default to vmcnt(0).
+
+; GCN-LABEL: {{^}}buffer_load_lds_dword_10_arrays:
+; GCN-COUNT-10: buffer_load_dword
+; GCN: s_waitcnt vmcnt(8)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(7)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(6)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(5)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(4)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(3)
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(2)
+; GCN-NOT: s_waitcnt vmcnt
+; GCN: ds_read_b32
+; GCN: s_waitcnt vmcnt(0)
+; GCN: ds_read_b32
+define amdgpu_kernel void @buffer_load_lds_dword_10_arrays(<4 x i32> %rsrc, i32 %i1, i32 %i2, i32 %i3, i32 %i4, i32 %i5, i32 %i6, i32 %i7, i32 %i8, i32 %i9, ptr addrspace(1) %out) {
+main_body:
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.0, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.1, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.2, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.3, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.4, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.5, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.6, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.7, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.8, i32 4, i32 0, i32 0, i32 0, i32 0)
+  call void @llvm.amdgcn.raw.buffer.load.lds(<4 x i32> %rsrc, ptr addrspace(3) @lds.9, i32 4, i32 0, i32 0, i32 0, i32 0)
+  %gep.0 = getelementptr float, ptr addrspace(3) @lds.0, i32 %i1
+  %gep.1 = getelementptr float, ptr addrspace(3) @lds.1, i32 %i2
+  %gep.2 = getelementptr float, ptr addrspace(3) @lds.2, i32 %i2
+  %gep.3 = getelementptr float, ptr addrspace(3) @lds.3, i32 %i2
+  %gep.4 = getelementptr float, ptr addrspace(3) @lds.4, i32 %i2
+  %gep.5 = getelementptr float, ptr addrspace(3) @lds.5, i32 %i2
+  %gep.6 = getelementptr float, ptr addrspace(3) @lds.6, i32 %i2
+  %gep.7 = getelementptr float, ptr addrspace(3) @lds.7, i32 %i2
+  %gep.8 = getelementptr float, ptr addrspace(3) @lds.8, i32 %i2
+  %gep.9 = getelementptr float, ptr addrspace(3) @lds.9, i32 %i2
+  %val.0 = load float, ptr addrspace(3) %gep.0, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.1 = load float, ptr addrspace(3) %gep.1, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.2 = load float, ptr addrspace(3) %gep.2, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.3 = load float, ptr addrspace(3) %gep.3, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.4 = load float, ptr addrspace(3) %gep.4, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.5 = load float, ptr addrspace(3) %gep.5, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.6 = load float, ptr addrspace(3) %gep.6, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.7 = load float, ptr addrspace(3) %gep.7, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.8 = load float, ptr addrspace(3) %gep.8, align 4
+  call void @llvm.amdgcn.wave.barrier()
+  %val.9 = load float, ptr addrspace(3) %gep.9, align 4
+  %out.gep.1 = getelementptr float, ptr addrspace(1) %out, i32 1
+  %out.gep.2 = getelementptr float, ptr addrspace(1) %out, i32 2
+  %out.gep.3 = getelementptr float, ptr addrspace(1) %out, i32 3
+  %out.gep.4 = getelementptr float, ptr addrspace(1) %out, i32 4
+  %out.gep.5 = getelementptr float, ptr addrspace(1) %out, i32 5
+  %out.gep.6 = getelementptr float, ptr addrspace(1) %out, i32 6
+  %out.gep.7 = getelementptr float, ptr addrspace(1) %out, i32 7
+  %out.gep.8 = getelementptr float, ptr addrspace(1) %out, i32 8
+  %out.gep.9 = getelementptr float, ptr addrspace(1) %out, i32 9
+  store float %val.0, ptr addrspace(1) %out
+  store float %val.1, ptr addrspace(1) %out.gep.1
+  store float %val.2, ptr addrspace(1) %out.gep.2
+  store float %val.3, ptr addrspace(1) %out.gep.3
+  store float %val.4, ptr addrspace(1) %out.gep.4
+  store float %val.5, ptr addrspace(1) %out.gep.5
+  store float %val.6, ptr addrspace(1) %out.gep.6
+  store float %val.7, ptr addrspace(1) %out.gep.7
+  store float %val.8, ptr addrspace(1) %out.gep.8
+  store float %val.9, ptr addrspace(1) %out.gep.9
+  ret void
+}
+
 declare void @llvm.amdgcn.wave.barrier()


        


More information about the llvm-commits mailing list