[llvm] [AMDGPU] Flatten recursive register resource info propagation (PR #142766)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 4 04:23:28 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-backend-amdgpu

Author: Janek van Oirschot (JanekvO)

<details>
<summary>Changes</summary>

In #112251 I mentioned I'd follow up with flattening of recursion for register resource info propagation.

Prior to this patch, when a recursive call was encountered, the module-scope worst-case function register use was taken (even prior to AMDGPUMCResourceInfo). With this patch, when a cycle is detected, a simple cycle-avoiding DFS attempts to find the worst-case constant within the cycle and the values propagated into it. In other words, it looks for the cycle-scope worst case rather than the module-scope worst case.
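
Conceptually, the flattening is a worklist DFS over the `max(...)` expression graph: constants update a running maximum, symbol references are followed at most once (which breaks the cycle), and `max` nodes just enqueue their arguments. Below is a minimal standalone sketch of that walk, using plain STL types as hypothetical stand-ins for the MCExpr hierarchy (the module-max fallback for unexpected expression shapes is omitted):

```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Simplified stand-in for MCExpr: a node is a constant, a reference to
// another function's resource symbol, or a max() over sub-expressions.
struct Expr {
  enum Kind { Constant, SymbolRef, Max } K;
  int64_t Value = 0;                // Valid when K == Constant.
  const Expr *Referenced = nullptr; // Valid when K == SymbolRef.
  std::vector<const Expr *> Args;   // Valid when K == Max.
};

// Cycle-avoiding DFS: collect the maximum constant reachable through the
// cycle's expressions, visiting each referenced expression only once.
int64_t flattenedCycleMax(const Expr *Root) {
  std::unordered_set<const Expr *> Seen;
  std::vector<const Expr *> WorkList{Root};
  int64_t Maximum = 0;
  while (!WorkList.empty()) {
    const Expr *Cur = WorkList.back();
    WorkList.pop_back();
    switch (Cur->K) {
    case Expr::Constant:
      Maximum = std::max(Maximum, Cur->Value);
      break;
    case Expr::SymbolRef:
      // Only descend into expressions not visited yet; this is what keeps
      // the walk from looping on the recursion back-edge.
      if (Cur->Referenced && Seen.insert(Cur->Referenced).second)
        WorkList.push_back(Cur->Referenced);
      break;
    case Expr::Max:
      for (const Expr *Arg : Cur->Args)
        WorkList.push_back(Arg);
      break;
    }
  }
  return Maximum;
}
```

This is what turns pre-patch expressions like `max(48, amdgpu.max_num_vgpr)` in the tests below into concrete forms like `max(48, 43)`.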

---
Full diff: https://github.com/llvm/llvm-project/pull/142766.diff


4 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp (+88-13) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.h (+6) 
- (modified) llvm/test/CodeGen/AMDGPU/function-resource-usage.ll (+16-16) 
- (modified) llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll (+79-3) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
index 7d2596d666185..37a3b99baa2ac 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -97,6 +97,84 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
   return OutContext.getOrCreateSymbol("amdgpu.max_num_sgpr");
 }
 
+// Tries to flatten recursive call register resource gathering: a simple
+// cycle-avoiding DFS that collects the constants in the propagated symbols.
+// Assumes:
+// - RecSym has been confirmed to recurse (this means the callee symbols should
+//   all be populated, starting at RecSym).
+// - Shape of the resource symbol's MCExpr (`max` args are order agnostic):
+//   RecSym.MCExpr := max(<constant>+, <callee_symbol>*)
+const MCExpr *MCResourceInfo::flattenedCycleMax(MCSymbol *RecSym,
+                                                ResourceInfoKind RIK,
+                                                MCContext &OutContext) {
+  SmallPtrSet<const MCExpr *, 8> Seen;
+  SmallVector<const MCExpr *, 8> WorkList;
+  int64_t Maximum = 0;
+
+  const MCExpr *RecExpr = RecSym->getVariableValue();
+  WorkList.push_back(RecExpr);
+
+  while (!WorkList.empty()) {
+    const MCExpr *CurExpr = WorkList.pop_back_val();
+    switch (CurExpr->getKind()) {
+    default: {
+      // The recursion is assumed to have the shape `max(<constant>,
+      // <callee_symbol>)` where <callee_symbol> eventually recurses. Reaching
+      // this case means the recursion occurs within some other (possibly
+      // unresolvable) MCExpr, so fall back to the module-wide worst case.
+      if (CurExpr->isSymbolUsedInExpression(RecSym)) {
+        LLVM_DEBUG(dbgs() << "MCResUse:   " << RecSym->getName()
+                          << ": Recursion in unexpected sub-expression, using "
+                             "module maximum\n");
+        switch (RIK) {
+        default:
+          break;
+        case RIK_NumVGPR:
+          return MCSymbolRefExpr::create(getMaxVGPRSymbol(OutContext),
+                                         OutContext);
+        case RIK_NumSGPR:
+          return MCSymbolRefExpr::create(getMaxSGPRSymbol(OutContext),
+                                         OutContext);
+        case RIK_NumAGPR:
+          return MCSymbolRefExpr::create(getMaxAGPRSymbol(OutContext),
+                                         OutContext);
+        }
+      }
+      break;
+    }
+    case MCExpr::ExprKind::Constant: {
+      int64_t Val = cast<MCConstantExpr>(CurExpr)->getValue();
+      Maximum = std::max(Maximum, Val);
+      break;
+    }
+    case MCExpr::ExprKind::SymbolRef: {
+      const MCSymbolRefExpr *SymExpr = cast<MCSymbolRefExpr>(CurExpr);
+      const MCSymbol &SymRef = SymExpr->getSymbol();
+      if (SymRef.isVariable()) {
+        const MCExpr *SymVal = SymRef.getVariableValue();
+        auto [_, Inserted] = Seen.insert(SymVal);
+        if (Inserted)
+          WorkList.push_back(SymVal);
+      }
+      break;
+    }
+    case MCExpr::ExprKind::Target: {
+      const AMDGPUMCExpr *TargetExpr = cast<AMDGPUMCExpr>(CurExpr);
+      if (TargetExpr->getKind() == AMDGPUMCExpr::VariantKind::AGVK_Max) {
+        for (auto &Arg : TargetExpr->getArgs())
+          WorkList.push_back(Arg);
+      }
+      break;
+    }
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "MCResUse:   " << RecSym->getName()
+                    << ": Using flattened max: " << Maximum << '\n');
+
+  return MCConstantExpr::create(Maximum, OutContext);
+}
+
 void MCResourceInfo::assignResourceInfoExpr(
     int64_t LocalValue, ResourceInfoKind RIK, AMDGPUMCExpr::VariantKind Kind,
     const MachineFunction &MF, const SmallVectorImpl<const Function *> &Callees,
@@ -132,25 +213,19 @@ void MCResourceInfo::assignResourceInfoExpr(
                           << CalleeValSym->getName() << " as callee\n");
         ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
       } else {
-        LLVM_DEBUG(
-            dbgs() << "MCResUse:   " << Sym->getName()
-                   << ": Recursion found, falling back to module maximum\n");
-        // In case of recursion: make sure to use conservative register counts
-        // (i.e., specifically for VGPR/SGPR/AGPR).
+        LLVM_DEBUG(dbgs() << "MCResUse:   " << Sym->getName()
+                          << ": Recursion found, attempt flattening of cycle "
+                             "for resource usage\n");
+        // In case of recursion for VGPR/SGPR/AGPR resource usage: try to
+        // flatten and use the max of the call cycle. May still end up emitting
+        // the module max if the cycle is not fully resolvable.
         switch (RIK) {
         default:
           break;
         case RIK_NumVGPR:
-          ArgExprs.push_back(MCSymbolRefExpr::create(
-              getMaxVGPRSymbol(OutContext), OutContext));
-          break;
         case RIK_NumSGPR:
-          ArgExprs.push_back(MCSymbolRefExpr::create(
-              getMaxSGPRSymbol(OutContext), OutContext));
-          break;
         case RIK_NumAGPR:
-          ArgExprs.push_back(MCSymbolRefExpr::create(
-              getMaxAGPRSymbol(OutContext), OutContext));
+          ArgExprs.push_back(flattenedCycleMax(CalleeValSym, RIK, OutContext));
           break;
         }
       }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.h
index a670878948c31..fa98f82d11022 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.h
@@ -58,6 +58,12 @@ class MCResourceInfo {
   // Assigns expression for Max S/V/A-GPRs to the referenced symbols.
   void assignMaxRegs(MCContext &OutContext);
 
+  // Take the flattened max over a call cycle's known values. For example, for
+  // a cycle A->B->C->D->A, take max(A, B, C, D) for A and have B, C, and D
+  // take the propagated value from A.
+  const MCExpr *flattenedCycleMax(MCSymbol *RecSym, ResourceInfoKind RIK,
+                                  MCContext &OutContext);
+
 public:
   MCResourceInfo() = default;
   void addMaxVGPRCandidate(int32_t candidate) {
diff --git a/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll b/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
index 0a6aa05c2d212..2a18d40e0bd8a 100644
--- a/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
+++ b/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
@@ -495,17 +495,17 @@ define amdgpu_kernel void @usage_direct_recursion(i32 %n) #0 {
 ; GCN: NumVgprs: max(43, multi_stage_recurse1.num_vgpr)
 ; GCN: ScratchSize: 16+max(multi_stage_recurse1.private_seg_size)
 ; GCN-LABEL: {{^}}multi_stage_recurse1:
-; GCN: .set multi_stage_recurse1.num_vgpr, max(48, amdgpu.max_num_vgpr)
-; GCN: .set multi_stage_recurse1.num_agpr, max(0, amdgpu.max_num_agpr)
-; GCN: .set multi_stage_recurse1.numbered_sgpr, max(34, amdgpu.max_num_sgpr)
+; GCN: .set multi_stage_recurse1.num_vgpr, max(48, 43)
+; GCN: .set multi_stage_recurse1.num_agpr, max(0, 0)
+; GCN: .set multi_stage_recurse1.numbered_sgpr, max(34, 34)
 ; GCN: .set multi_stage_recurse1.private_seg_size, 16
 ; GCN: .set multi_stage_recurse1.uses_vcc, 1
 ; GCN: .set multi_stage_recurse1.uses_flat_scratch, 0
 ; GCN: .set multi_stage_recurse1.has_dyn_sized_stack, 0
 ; GCN: .set multi_stage_recurse1.has_recursion, 1
 ; GCN: .set multi_stage_recurse1.has_indirect_call, 0
-; GCN: TotalNumSgprs: multi_stage_recurse1.numbered_sgpr+4
-; GCN: NumVgprs: max(48, amdgpu.max_num_vgpr)
+; GCN: TotalNumSgprs: 38
+; GCN: NumVgprs: 48
 ; GCN: ScratchSize: 16
 define void @multi_stage_recurse1(i32 %val) #2 {
   call void @multi_stage_recurse2(i32 %val)
@@ -528,8 +528,8 @@ define void @multi_stage_recurse2(i32 %val) #2 {
 ; GCN: .set usage_multi_stage_recurse.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
 ; GCN: .set usage_multi_stage_recurse.has_recursion, or(1, multi_stage_recurse1.has_recursion)
 ; GCN: .set usage_multi_stage_recurse.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
-; GCN: TotalNumSgprs: usage_multi_stage_recurse.numbered_sgpr+6
-; GCN: NumVgprs: usage_multi_stage_recurse.num_vgpr
+; GCN: TotalNumSgprs: 40
+; GCN: NumVgprs: 48
 ; GCN: ScratchSize: 16
 define amdgpu_kernel void @usage_multi_stage_recurse(i32 %n) #0 {
   call void @multi_stage_recurse1(i32 %n)
@@ -550,17 +550,17 @@ define amdgpu_kernel void @usage_multi_stage_recurse(i32 %n) #0 {
 ; GCN: NumVgprs: max(41, multi_stage_recurse_noattr1.num_vgpr)
 ; GCN: ScratchSize: 16+max(multi_stage_recurse_noattr1.private_seg_size)
 ; GCN-LABEL: {{^}}multi_stage_recurse_noattr1:
-; GCN: .set multi_stage_recurse_noattr1.num_vgpr, max(41, amdgpu.max_num_vgpr)
-; GCN: .set multi_stage_recurse_noattr1.num_agpr, max(0, amdgpu.max_num_agpr)
-; GCN: .set multi_stage_recurse_noattr1.numbered_sgpr, max(57, amdgpu.max_num_sgpr)
+; GCN: .set multi_stage_recurse_noattr1.num_vgpr, max(41, 41)
+; GCN: .set multi_stage_recurse_noattr1.num_agpr, max(0, 0)
+; GCN: .set multi_stage_recurse_noattr1.numbered_sgpr, max(57, 54)
 ; GCN: .set multi_stage_recurse_noattr1.private_seg_size, 16
 ; GCN: .set multi_stage_recurse_noattr1.uses_vcc, 1
 ; GCN: .set multi_stage_recurse_noattr1.uses_flat_scratch, 0
 ; GCN: .set multi_stage_recurse_noattr1.has_dyn_sized_stack, 0
 ; GCN: .set multi_stage_recurse_noattr1.has_recursion, 0
 ; GCN: .set multi_stage_recurse_noattr1.has_indirect_call, 0
-; GCN: TotalNumSgprs: multi_stage_recurse_noattr1.numbered_sgpr+4
-; GCN: NumVgprs: max(41, amdgpu.max_num_vgpr)
+; GCN: TotalNumSgprs: 61
+; GCN: NumVgprs: 41
 ; GCN: ScratchSize: 16
 define void @multi_stage_recurse_noattr1(i32 %val) #0 {
   call void @multi_stage_recurse_noattr2(i32 %val)
@@ -583,8 +583,8 @@ define void @multi_stage_recurse_noattr2(i32 %val) #0 {
 ; GCN: .set usage_multi_stage_recurse_noattrs.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
 ; GCN: .set usage_multi_stage_recurse_noattrs.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
 ; GCN: .set usage_multi_stage_recurse_noattrs.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
-; GCN: TotalNumSgprs: usage_multi_stage_recurse_noattrs.numbered_sgpr+6
-; GCN: NumVgprs: usage_multi_stage_recurse_noattrs.num_vgpr
+; GCN: TotalNumSgprs: 63
+; GCN: NumVgprs: 41
 ; GCN: ScratchSize: 16
 define amdgpu_kernel void @usage_multi_stage_recurse_noattrs(i32 %n) #0 {
   call void @multi_stage_recurse_noattr1(i32 %n)
@@ -601,8 +601,8 @@ define amdgpu_kernel void @usage_multi_stage_recurse_noattrs(i32 %n) #0 {
 ; GCN:  .set multi_call_with_multi_stage_recurse.has_dyn_sized_stack, or(0, use_stack0.has_dyn_sized_stack, use_stack1.has_dyn_sized_stack, multi_stage_recurse1.has_dyn_sized_stack)
 ; GCN:  .set multi_call_with_multi_stage_recurse.has_recursion, or(1, use_stack0.has_recursion, use_stack1.has_recursion, multi_stage_recurse1.has_recursion)
 ; GCN:  .set multi_call_with_multi_stage_recurse.has_indirect_call, or(0, use_stack0.has_indirect_call, use_stack1.has_indirect_call, multi_stage_recurse1.has_indirect_call)
-; GCN: TotalNumSgprs: multi_call_with_multi_stage_recurse.numbered_sgpr+6
-; GCN: NumVgprs:  multi_call_with_multi_stage_recurse.num_vgpr
+; GCN: TotalNumSgprs: 59
+; GCN: NumVgprs:  48
 ; GCN: ScratchSize: 2052
 define amdgpu_kernel void @multi_call_with_multi_stage_recurse(i32 %n) #0 {
   call void @use_stack0()
diff --git a/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll b/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll
index 3093349bff37c..a41a06592f62f 100644
--- a/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll
+++ b/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll
@@ -1,5 +1,7 @@
 ; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx90a < %s | FileCheck %s
 
+; Recursion: foo -> bar -> baz -> qux -> foo
+
 ; CHECK-LABEL: {{^}}qux
 ; CHECK: .set qux.num_vgpr, max(71, foo.num_vgpr)
 ; CHECK: .set qux.num_agpr, max(0, foo.num_agpr)
@@ -34,9 +36,9 @@
 ; CHECK: .set bar.has_indirect_call, or(0, baz.has_indirect_call)
 
 ; CHECK-LABEL: {{^}}foo
-; CHECK: .set foo.num_vgpr, max(46, amdgpu.max_num_vgpr)
-; CHECK: .set foo.num_agpr, max(0, amdgpu.max_num_agpr)
-; CHECK: .set foo.numbered_sgpr, max(71, amdgpu.max_num_sgpr)
+; CHECK: .set foo.num_vgpr, max(46, 71)
+; CHECK: .set foo.num_agpr, max(0, 0)
+; CHECK: .set foo.numbered_sgpr, max(71, 61)
 ; CHECK: .set foo.private_seg_size, 16
 ; CHECK: .set foo.uses_vcc, 1
 ; CHECK: .set foo.uses_flat_scratch, 0
@@ -91,3 +93,77 @@ define amdgpu_kernel void @usefoo() {
   ret void
 }
 
+; Recursion: A -> B -> C -> A && C -> D -> C
+
+; CHECK-LABEL: {{^}}D
+; CHECK: .set D.num_vgpr, max(71, C.num_vgpr)
+; CHECK: .set D.num_agpr, max(0, C.num_agpr)
+; CHECK: .set D.numbered_sgpr, max(71, C.numbered_sgpr)
+; CHECK: .set D.private_seg_size, 16+max(C.private_seg_size)
+; CHECK: .set D.uses_vcc, or(1, C.uses_vcc)
+; CHECK: .set D.uses_flat_scratch, or(0, C.uses_flat_scratch)
+; CHECK: .set D.has_dyn_sized_stack, or(0, C.has_dyn_sized_stack)
+; CHECK: .set D.has_recursion, or(1, C.has_recursion)
+; CHECK: .set D.has_indirect_call, or(0, C.has_indirect_call)
+
+; CHECK-LABEL: {{^}}C
+; CHECK: .set C.num_vgpr, max(42, A.num_vgpr, 71)
+; CHECK: .set C.num_agpr, max(0, A.num_agpr, 0)
+; CHECK: .set C.numbered_sgpr, max(71, A.numbered_sgpr, 71)
+; CHECK: .set C.private_seg_size, 16+max(A.private_seg_size)
+; CHECK: .set C.uses_vcc, or(1, A.uses_vcc)
+; CHECK: .set C.uses_flat_scratch, or(0, A.uses_flat_scratch)
+; CHECK: .set C.has_dyn_sized_stack, or(0, A.has_dyn_sized_stack)
+; CHECK: .set C.has_recursion, or(1, A.has_recursion)
+; CHECK: .set C.has_indirect_call, or(0, A.has_indirect_call)
+
+; CHECK-LABEL: {{^}}B
+; CHECK: .set B.num_vgpr, max(42, C.num_vgpr)
+; CHECK: .set B.num_agpr, max(0, C.num_agpr)
+; CHECK: .set B.numbered_sgpr, max(71, C.numbered_sgpr)
+; CHECK: .set B.private_seg_size, 16+max(C.private_seg_size)
+; CHECK: .set B.uses_vcc, or(1, C.uses_vcc)
+; CHECK: .set B.uses_flat_scratch, or(0, C.uses_flat_scratch)
+; CHECK: .set B.has_dyn_sized_stack, or(0, C.has_dyn_sized_stack)
+; CHECK: .set B.has_recursion, or(1, C.has_recursion)
+; CHECK: .set B.has_indirect_call, or(0, C.has_indirect_call)
+
+; CHECK-LABEL: {{^}}A
+; CHECK: .set A.num_vgpr, max(42, 71)
+; CHECK: .set A.num_agpr, max(0, 0)
+; CHECK: .set A.numbered_sgpr, max(71, 71)
+; CHECK: .set A.private_seg_size, 16
+; CHECK: .set A.uses_vcc, 1
+; CHECK: .set A.uses_flat_scratch, 0
+; CHECK: .set A.has_dyn_sized_stack, 0
+; CHECK: .set A.has_recursion, 1
+; CHECK: .set A.has_indirect_call, 0
+
+define void @A() {
+  call void @B()
+  call void asm sideeffect "", "~{v10}"()
+  call void asm sideeffect "", "~{s50}"()
+  ret void
+}
+
+define void @B() {
+  call void @C()
+  call void asm sideeffect "", "~{v20}"()
+  call void asm sideeffect "", "~{s30}"()
+  ret void
+}
+
+define void @C() {
+  call void @A()
+  call void @D()
+  call void asm sideeffect "", "~{v30}"()
+  call void asm sideeffect "", "~{s40}"()
+  ret void
+}
+
+define void @D() {
+  call void @C()
+  call void asm sideeffect "", "~{v70}"()
+  call void asm sideeffect "", "~{s70}"()
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/142766

