[llvm] Reapply [AMDGPU] Avoid resource propagation for recursion through multiple functions (PR #112251)

Janek van Oirschot via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 12:31:30 PST 2024


https://github.com/JanekvO updated https://github.com/llvm/llvm-project/pull/112251

>From 1b882858ed8690c392fa27088d03497f56e9c665 Mon Sep 17 00:00:00 2001
From: Janek van Oirschot <janek.vanoirschot at amd.com>
Date: Mon, 14 Oct 2024 20:33:51 +0100
Subject: [PATCH 1/5] Reapply [AMDGPU] Avoid resource propagation for recursion
 through multiple functions

I was wrong in the last patch. I viewed the visited set purely as a
recursion deterrent, assuming that functions calling the same callee
multiple times were handled elsewhere. That view did not consider cases
where a function is called multiple times by different callers that are
still part of the same call graph. The new tests show this case.

Reapplies #111004
---
 .../Target/AMDGPU/AMDGPUMCResourceInfo.cpp    |  89 ++++++++++++-
 .../CodeGen/AMDGPU/function-resource-usage.ll | 126 ++++++++++++++++++
 .../multi-call-resource-usage-mcexpr.ll       |  82 ++++++++++++
 .../AMDGPU/recursive-resource-usage-mcexpr.ll |  85 ++++++++++++
 4 files changed, 375 insertions(+), 7 deletions(-)
 create mode 100644 llvm/test/CodeGen/AMDGPU/multi-call-resource-usage-mcexpr.ll
 create mode 100644 llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
index da0397fa20bd1b..ee1453d1d733ba 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -91,6 +91,68 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
   return OutContext.getOrCreateSymbol("amdgpu.max_num_sgpr");
 }
 
+// The (partially complete) expression should have no recursion in it. After
+// all, we're trying to avoid recursion using this codepath. Returns true if
+// Sym is found within Expr without recursing on Expr, false otherwise.
+static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
+                             SmallVectorImpl<const MCExpr *> &Exprs,
+                             SmallPtrSetImpl<const MCExpr *> &Visited) {
+  // Skip duplicate visits
+  if (!Visited.insert(Expr).second)
+    return false;
+
+  switch (Expr->getKind()) {
+  default:
+    return false;
+  case MCExpr::ExprKind::SymbolRef: {
+    const MCSymbolRefExpr *SymRefExpr = cast<MCSymbolRefExpr>(Expr);
+    const MCSymbol &SymRef = SymRefExpr->getSymbol();
+    if (Sym == &SymRef)
+      return true;
+    if (SymRef.isVariable())
+      Exprs.push_back(SymRef.getVariableValue(/*isUsed=*/false));
+    return false;
+  }
+  case MCExpr::ExprKind::Binary: {
+    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
+    Exprs.push_back(BExpr->getLHS());
+    Exprs.push_back(BExpr->getRHS());
+    return false;
+  }
+  case MCExpr::ExprKind::Unary: {
+    const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
+    Exprs.push_back(UExpr->getSubExpr());
+    return false;
+  }
+  case MCExpr::ExprKind::Target: {
+    const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
+    for (const MCExpr *E : AGVK->getArgs())
+      Exprs.push_back(E);
+    return false;
+  }
+  }
+}
+
+// Symbols whose values eventually are used through their defines (i.e.,
+// recursive) must be avoided. Do a walk over Expr to see if Sym will occur in
+// it. The Expr is an MCExpr given through a callee's equivalent MCSymbol so if
+// no recursion is found Sym can be safely assigned to a (sub-)expr which
+// contains the symbol Expr is associated with. Returns true if Sym exists
+// in Expr or its sub-expressions, false otherwise.
+static bool foundRecursiveSymbolDef(MCSymbol *Sym, const MCExpr *Expr) {
+  SmallVector<const MCExpr *, 8> WorkList;
+  SmallPtrSet<const MCExpr *, 8> Visited;
+  WorkList.push_back(Expr);
+
+  while (!WorkList.empty()) {
+    const MCExpr *CurExpr = WorkList.pop_back_val();
+    if (findSymbolInExpr(Sym, CurExpr, WorkList, Visited))
+      return true;
+  }
+
+  return false;
+}
+
 void MCResourceInfo::assignResourceInfoExpr(
     int64_t LocalValue, ResourceInfoKind RIK, AMDGPUMCExpr::VariantKind Kind,
     const MachineFunction &MF, const SmallVectorImpl<const Function *> &Callees,
@@ -98,6 +160,7 @@ void MCResourceInfo::assignResourceInfoExpr(
   const MCConstantExpr *LocalConstExpr =
       MCConstantExpr::create(LocalValue, OutContext);
   const MCExpr *SymVal = LocalConstExpr;
+  MCSymbol *Sym = getSymbol(MF.getName(), RIK, OutContext);
   if (!Callees.empty()) {
     SmallVector<const MCExpr *, 8> ArgExprs;
     // Avoid recursive symbol assignment.
@@ -110,11 +173,17 @@ void MCResourceInfo::assignResourceInfoExpr(
       if (!Seen.insert(Callee).second)
         continue;
       MCSymbol *CalleeValSym = getSymbol(Callee->getName(), RIK, OutContext);
-      ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+      bool CalleeIsVar = CalleeValSym->isVariable();
+      if (!CalleeIsVar ||
+          (CalleeIsVar &&
+           !foundRecursiveSymbolDef(
+               Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))) {
+        ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+      }
     }
-    SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
+    if (ArgExprs.size() > 1)
+      SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
   }
-  MCSymbol *Sym = getSymbol(MF.getName(), RIK, OutContext);
   Sym->setVariableValue(SymVal);
 }
 
@@ -155,6 +224,7 @@ void MCResourceInfo::gatherResourceInfo(
     // The expression for private segment size should be: FRI.PrivateSegmentSize
     // + max(FRI.Callees, FRI.CalleeSegmentSize)
     SmallVector<const MCExpr *, 8> ArgExprs;
+    MCSymbol *Sym = getSymbol(MF.getName(), RIK_PrivateSegSize, OutContext);
     if (FRI.CalleeSegmentSize)
       ArgExprs.push_back(
           MCConstantExpr::create(FRI.CalleeSegmentSize, OutContext));
@@ -165,9 +235,15 @@ void MCResourceInfo::gatherResourceInfo(
       if (!Seen.insert(Callee).second)
         continue;
       if (!Callee->isDeclaration()) {
-        MCSymbol *calleeValSym =
+        MCSymbol *CalleeValSym =
             getSymbol(Callee->getName(), RIK_PrivateSegSize, OutContext);
-        ArgExprs.push_back(MCSymbolRefExpr::create(calleeValSym, OutContext));
+        bool CalleeIsVar = CalleeValSym->isVariable();
+        if (!CalleeIsVar ||
+            (CalleeIsVar &&
+             !foundRecursiveSymbolDef(
+                 Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))) {
+          ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+        }
       }
     }
     const MCExpr *localConstExpr =
@@ -178,8 +254,7 @@ void MCResourceInfo::gatherResourceInfo(
       localConstExpr =
           MCBinaryExpr::createAdd(localConstExpr, transitiveExpr, OutContext);
     }
-    getSymbol(MF.getName(), RIK_PrivateSegSize, OutContext)
-        ->setVariableValue(localConstExpr);
+    Sym->setVariableValue(localConstExpr);
   }
 
   auto SetToLocal = [&](int64_t LocalValue, ResourceInfoKind RIK) {
diff --git a/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll b/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
index d3a6b4e01ebfb8..c8cf7d7e535b33 100644
--- a/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
+++ b/llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
@@ -481,6 +481,132 @@ define amdgpu_kernel void @usage_direct_recursion(i32 %n) #0 {
   ret void
 }
 
+; GCN-LABEL: {{^}}multi_stage_recurse2:
+; GCN: .set multi_stage_recurse2.num_vgpr, max(41, multi_stage_recurse1.num_vgpr)
+; GCN: .set multi_stage_recurse2.num_agpr, max(0, multi_stage_recurse1.num_agpr)
+; GCN: .set multi_stage_recurse2.numbered_sgpr, max(34, multi_stage_recurse1.numbered_sgpr)
+; GCN: .set multi_stage_recurse2.private_seg_size, 16+(max(multi_stage_recurse1.private_seg_size))
+; GCN: .set multi_stage_recurse2.uses_vcc, or(1, multi_stage_recurse1.uses_vcc)
+; GCN: .set multi_stage_recurse2.uses_flat_scratch, or(0, multi_stage_recurse1.uses_flat_scratch)
+; GCN: .set multi_stage_recurse2.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
+; GCN: .set multi_stage_recurse2.has_recursion, or(1, multi_stage_recurse1.has_recursion)
+; GCN: .set multi_stage_recurse2.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
+; GCN: TotalNumSgprs: multi_stage_recurse2.numbered_sgpr+(extrasgprs(multi_stage_recurse2.uses_vcc, multi_stage_recurse2.uses_flat_scratch, 1))
+; GCN: NumVgprs: max(41, multi_stage_recurse1.num_vgpr)
+; GCN: ScratchSize: 16+(max(multi_stage_recurse1.private_seg_size))
+; GCN-LABEL: {{^}}multi_stage_recurse1:
+; GCN: .set multi_stage_recurse1.num_vgpr, 41
+; GCN: .set multi_stage_recurse1.num_agpr, 0
+; GCN: .set multi_stage_recurse1.numbered_sgpr, 34
+; GCN: .set multi_stage_recurse1.private_seg_size, 16
+; GCN: .set multi_stage_recurse1.uses_vcc, 1
+; GCN: .set multi_stage_recurse1.uses_flat_scratch, 0
+; GCN: .set multi_stage_recurse1.has_dyn_sized_stack, 0
+; GCN: .set multi_stage_recurse1.has_recursion, 1
+; GCN: .set multi_stage_recurse1.has_indirect_call, 0
+; GCN: TotalNumSgprs: 38
+; GCN: NumVgprs: 41
+; GCN: ScratchSize: 16
+define void @multi_stage_recurse1(i32 %val) #2 {
+  call void @multi_stage_recurse2(i32 %val)
+  ret void
+}
+define void @multi_stage_recurse2(i32 %val) #2 {
+  call void @multi_stage_recurse1(i32 %val)
+  ret void
+}
+
+; GCN-LABEL: {{^}}usage_multi_stage_recurse:
+; GCN: .set usage_multi_stage_recurse.num_vgpr, max(32, multi_stage_recurse1.num_vgpr)
+; GCN: .set usage_multi_stage_recurse.num_agpr, max(0, multi_stage_recurse1.num_agpr)
+; GCN: .set usage_multi_stage_recurse.numbered_sgpr, max(33, multi_stage_recurse1.numbered_sgpr)
+; GCN: .set usage_multi_stage_recurse.private_seg_size, 0+(max(multi_stage_recurse1.private_seg_size))
+; GCN: .set usage_multi_stage_recurse.uses_vcc, or(1, multi_stage_recurse1.uses_vcc)
+; GCN: .set usage_multi_stage_recurse.uses_flat_scratch, or(1, multi_stage_recurse1.uses_flat_scratch)
+; GCN: .set usage_multi_stage_recurse.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
+; GCN: .set usage_multi_stage_recurse.has_recursion, or(1, multi_stage_recurse1.has_recursion)
+; GCN: .set usage_multi_stage_recurse.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
+; GCN: TotalNumSgprs: 40
+; GCN: NumVgprs: 41
+; GCN: ScratchSize: 16
+define amdgpu_kernel void @usage_multi_stage_recurse(i32 %n) #0 {
+  call void @multi_stage_recurse1(i32 %n)
+  ret void
+}
+
+; GCN-LABEL: {{^}}multi_stage_recurse_noattr2:
+; GCN: .set multi_stage_recurse_noattr2.num_vgpr, max(41, multi_stage_recurse_noattr1.num_vgpr)
+; GCN: .set multi_stage_recurse_noattr2.num_agpr, max(0, multi_stage_recurse_noattr1.num_agpr)
+; GCN: .set multi_stage_recurse_noattr2.numbered_sgpr, max(34, multi_stage_recurse_noattr1.numbered_sgpr)
+; GCN: .set multi_stage_recurse_noattr2.private_seg_size, 16+(max(multi_stage_recurse_noattr1.private_seg_size))
+; GCN: .set multi_stage_recurse_noattr2.uses_vcc, or(1, multi_stage_recurse_noattr1.uses_vcc)
+; GCN: .set multi_stage_recurse_noattr2.uses_flat_scratch, or(0, multi_stage_recurse_noattr1.uses_flat_scratch)
+; GCN: .set multi_stage_recurse_noattr2.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
+; GCN: .set multi_stage_recurse_noattr2.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
+; GCN: .set multi_stage_recurse_noattr2.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
+; GCN: TotalNumSgprs: multi_stage_recurse_noattr2.numbered_sgpr+(extrasgprs(multi_stage_recurse_noattr2.uses_vcc, multi_stage_recurse_noattr2.uses_flat_scratch, 1))
+; GCN: NumVgprs: max(41, multi_stage_recurse_noattr1.num_vgpr)
+; GCN: ScratchSize: 16+(max(multi_stage_recurse_noattr1.private_seg_size))
+; GCN-LABEL: {{^}}multi_stage_recurse_noattr1:
+; GCN: .set multi_stage_recurse_noattr1.num_vgpr, 41
+; GCN: .set multi_stage_recurse_noattr1.num_agpr, 0
+; GCN: .set multi_stage_recurse_noattr1.numbered_sgpr, 34
+; GCN: .set multi_stage_recurse_noattr1.private_seg_size, 16
+; GCN: .set multi_stage_recurse_noattr1.uses_vcc, 1
+; GCN: .set multi_stage_recurse_noattr1.uses_flat_scratch, 0
+; GCN: .set multi_stage_recurse_noattr1.has_dyn_sized_stack, 0
+; GCN: .set multi_stage_recurse_noattr1.has_recursion, 0
+; GCN: .set multi_stage_recurse_noattr1.has_indirect_call, 0
+; GCN: TotalNumSgprs: 38
+; GCN: NumVgprs: 41
+; GCN: ScratchSize: 16
+define void @multi_stage_recurse_noattr1(i32 %val) #0 {
+  call void @multi_stage_recurse_noattr2(i32 %val)
+  ret void
+}
+define void @multi_stage_recurse_noattr2(i32 %val) #0 {
+  call void @multi_stage_recurse_noattr1(i32 %val)
+  ret void
+}
+
+; GCN-LABEL: {{^}}usage_multi_stage_recurse_noattrs:
+; GCN: .set usage_multi_stage_recurse_noattrs.num_vgpr, max(32, multi_stage_recurse_noattr1.num_vgpr)
+; GCN: .set usage_multi_stage_recurse_noattrs.num_agpr, max(0, multi_stage_recurse_noattr1.num_agpr)
+; GCN: .set usage_multi_stage_recurse_noattrs.numbered_sgpr, max(33, multi_stage_recurse_noattr1.numbered_sgpr)
+; GCN: .set usage_multi_stage_recurse_noattrs.private_seg_size, 0+(max(multi_stage_recurse_noattr1.private_seg_size))
+; GCN: .set usage_multi_stage_recurse_noattrs.uses_vcc, or(1, multi_stage_recurse_noattr1.uses_vcc)
+; GCN: .set usage_multi_stage_recurse_noattrs.uses_flat_scratch, or(1, multi_stage_recurse_noattr1.uses_flat_scratch)
+; GCN: .set usage_multi_stage_recurse_noattrs.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
+; GCN: .set usage_multi_stage_recurse_noattrs.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
+; GCN: .set usage_multi_stage_recurse_noattrs.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
+; GCN: TotalNumSgprs: 40
+; GCN: NumVgprs: 41
+; GCN: ScratchSize: 16
+define amdgpu_kernel void @usage_multi_stage_recurse_noattrs(i32 %n) #0 {
+  call void @multi_stage_recurse_noattr1(i32 %n)
+  ret void
+}
+
+; GCN-LABEL: {{^}}multi_call_with_multi_stage_recurse:
+; GCN:  .set multi_call_with_multi_stage_recurse.num_vgpr, max(41, use_stack0.num_vgpr, use_stack1.num_vgpr, multi_stage_recurse1.num_vgpr)
+; GCN:  .set multi_call_with_multi_stage_recurse.num_agpr, max(0, use_stack0.num_agpr, use_stack1.num_agpr, multi_stage_recurse1.num_agpr)
+; GCN:  .set multi_call_with_multi_stage_recurse.numbered_sgpr, max(43, use_stack0.numbered_sgpr, use_stack1.numbered_sgpr, multi_stage_recurse1.numbered_sgpr)
+; GCN:  .set multi_call_with_multi_stage_recurse.private_seg_size, 0+(max(use_stack0.private_seg_size, use_stack1.private_seg_size, multi_stage_recurse1.private_seg_size))
+; GCN:  .set multi_call_with_multi_stage_recurse.uses_vcc, or(1, use_stack0.uses_vcc, use_stack1.uses_vcc, multi_stage_recurse1.uses_vcc)
+; GCN:  .set multi_call_with_multi_stage_recurse.uses_flat_scratch, or(1, use_stack0.uses_flat_scratch, use_stack1.uses_flat_scratch, multi_stage_recurse1.uses_flat_scratch)
+; GCN:  .set multi_call_with_multi_stage_recurse.has_dyn_sized_stack, or(0, use_stack0.has_dyn_sized_stack, use_stack1.has_dyn_sized_stack, multi_stage_recurse1.has_dyn_sized_stack)
+; GCN:  .set multi_call_with_multi_stage_recurse.has_recursion, or(1, use_stack0.has_recursion, use_stack1.has_recursion, multi_stage_recurse1.has_recursion)
+; GCN:  .set multi_call_with_multi_stage_recurse.has_indirect_call, or(0, use_stack0.has_indirect_call, use_stack1.has_indirect_call, multi_stage_recurse1.has_indirect_call)
+; GCN: TotalNumSgprs: 49
+; GCN: NumVgprs: 41
+; GCN: ScratchSize: 2052
+define amdgpu_kernel void @multi_call_with_multi_stage_recurse(i32 %n) #0 {
+  call void @use_stack0()
+  call void @use_stack1()
+  call void @multi_stage_recurse1(i32 %n)
+  ret void
+}
+
 ; Make sure there's no assert when a sgpr96 is used.
 ; GCN-LABEL: {{^}}count_use_sgpr96_external_call
 ; GCN:	.set count_use_sgpr96_external_call.num_vgpr, max(32, amdgpu.max_num_vgpr)
diff --git a/llvm/test/CodeGen/AMDGPU/multi-call-resource-usage-mcexpr.ll b/llvm/test/CodeGen/AMDGPU/multi-call-resource-usage-mcexpr.ll
new file mode 100644
index 00000000000000..c04bc96828e1f1
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/multi-call-resource-usage-mcexpr.ll
@@ -0,0 +1,82 @@
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx90a < %s | FileCheck %s
+
+; CHECK-LABEL: {{^}}qux
+; CHECK: .set qux.num_vgpr, 0
+; CHECK: .set qux.num_agpr, 0
+; CHECK: .set qux.numbered_sgpr, 32
+; CHECK: .set qux.private_seg_size, 0
+; CHECK: .set qux.uses_vcc, 0
+; CHECK: .set qux.uses_flat_scratch, 0
+; CHECK: .set qux.has_dyn_sized_stack, 0
+; CHECK: .set qux.has_recursion, 0
+; CHECK: .set qux.has_indirect_call, 0
+define void @qux() {
+entry:
+  ret void
+}
+
+; CHECK-LABEL: {{^}}baz
+; CHECK: .set baz.num_vgpr, max(32, qux.num_vgpr)
+; CHECK: .set baz.num_agpr, max(0, qux.num_agpr)
+; CHECK: .set baz.numbered_sgpr, max(34, qux.numbered_sgpr)
+; CHECK: .set baz.private_seg_size, 16+(max(qux.private_seg_size))
+; CHECK: .set baz.uses_vcc, or(0, qux.uses_vcc)
+; CHECK: .set baz.uses_flat_scratch, or(0, qux.uses_flat_scratch)
+; CHECK: .set baz.has_dyn_sized_stack, or(0, qux.has_dyn_sized_stack)
+; CHECK: .set baz.has_recursion, or(1, qux.has_recursion)
+; CHECK: .set baz.has_indirect_call, or(0, qux.has_indirect_call)
+define void @baz() {
+entry:
+  call void @qux()
+  ret void
+}
+
+; CHECK-LABEL: {{^}}bar
+; CHECK: .set bar.num_vgpr, max(32, baz.num_vgpr, qux.num_vgpr)
+; CHECK: .set bar.num_agpr, max(0, baz.num_agpr, qux.num_agpr)
+; CHECK: .set bar.numbered_sgpr, max(34, baz.numbered_sgpr, qux.numbered_sgpr)
+; CHECK: .set bar.private_seg_size, 16+(max(baz.private_seg_size, qux.private_seg_size))
+; CHECK: .set bar.uses_vcc, or(0, baz.uses_vcc, qux.uses_vcc)
+; CHECK: .set bar.uses_flat_scratch, or(0, baz.uses_flat_scratch, qux.uses_flat_scratch)
+; CHECK: .set bar.has_dyn_sized_stack, or(0, baz.has_dyn_sized_stack, qux.has_dyn_sized_stack)
+; CHECK: .set bar.has_recursion, or(1, baz.has_recursion, qux.has_recursion)
+; CHECK: .set bar.has_indirect_call, or(0, baz.has_indirect_call, qux.has_indirect_call)
+define void @bar() {
+entry:
+  call void @baz()
+  call void @qux()
+  call void @baz()
+  ret void
+}
+
+; CHECK-LABEL: {{^}}foo
+; CHECK: .set foo.num_vgpr, max(32, bar.num_vgpr)
+; CHECK: .set foo.num_agpr, max(0, bar.num_agpr)
+; CHECK: .set foo.numbered_sgpr, max(34, bar.numbered_sgpr)
+; CHECK: .set foo.private_seg_size, 16+(max(bar.private_seg_size))
+; CHECK: .set foo.uses_vcc, or(0, bar.uses_vcc)
+; CHECK: .set foo.uses_flat_scratch, or(0, bar.uses_flat_scratch)
+; CHECK: .set foo.has_dyn_sized_stack, or(0, bar.has_dyn_sized_stack)
+; CHECK: .set foo.has_recursion, or(1, bar.has_recursion)
+; CHECK: .set foo.has_indirect_call, or(0, bar.has_indirect_call)
+define void @foo() {
+entry:
+  call void @bar()
+  ret void
+}
+
+; CHECK-LABEL: {{^}}usefoo
+; CHECK: .set usefoo.num_vgpr, max(32, foo.num_vgpr)
+; CHECK: .set usefoo.num_agpr, max(0, foo.num_agpr)
+; CHECK: .set usefoo.numbered_sgpr, max(33, foo.numbered_sgpr)
+; CHECK: .set usefoo.private_seg_size, 0+(max(foo.private_seg_size))
+; CHECK: .set usefoo.uses_vcc, or(0, foo.uses_vcc)
+; CHECK: .set usefoo.uses_flat_scratch, or(1, foo.uses_flat_scratch)
+; CHECK: .set usefoo.has_dyn_sized_stack, or(0, foo.has_dyn_sized_stack)
+; CHECK: .set usefoo.has_recursion, or(1, foo.has_recursion)
+; CHECK: .set usefoo.has_indirect_call, or(0, foo.has_indirect_call)
+define amdgpu_kernel void @usefoo() {
+  call void @foo()
+  ret void
+}
+
diff --git a/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll b/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll
new file mode 100644
index 00000000000000..7e1090afc0cf1a
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll
@@ -0,0 +1,85 @@
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx90a < %s | FileCheck %s
+
+; CHECK-LABEL: {{^}}qux
+; CHECK: .set qux.num_vgpr, max(41, foo.num_vgpr)
+; CHECK: .set qux.num_agpr, max(0, foo.num_agpr)
+; CHECK: .set qux.numbered_sgpr, max(34, foo.numbered_sgpr)
+; CHECK: .set qux.private_seg_size, 16
+; CHECK: .set qux.uses_vcc, or(1, foo.uses_vcc)
+; CHECK: .set qux.uses_flat_scratch, or(0, foo.uses_flat_scratch)
+; CHECK: .set qux.has_dyn_sized_stack, or(0, foo.has_dyn_sized_stack)
+; CHECK: .set qux.has_recursion, or(1, foo.has_recursion)
+; CHECK: .set qux.has_indirect_call, or(0, foo.has_indirect_call)
+
+; CHECK-LABEL: {{^}}baz
+; CHECK: .set baz.num_vgpr, max(42, qux.num_vgpr)
+; CHECK: .set baz.num_agpr, max(0, qux.num_agpr)
+; CHECK: .set baz.numbered_sgpr, max(34, qux.numbered_sgpr)
+; CHECK: .set baz.private_seg_size, 16+(max(qux.private_seg_size))
+; CHECK: .set baz.uses_vcc, or(1, qux.uses_vcc)
+; CHECK: .set baz.uses_flat_scratch, or(0, qux.uses_flat_scratch)
+; CHECK: .set baz.has_dyn_sized_stack, or(0, qux.has_dyn_sized_stack)
+; CHECK: .set baz.has_recursion, or(1, qux.has_recursion)
+; CHECK: .set baz.has_indirect_call, or(0, qux.has_indirect_call)
+
+; CHECK-LABEL: {{^}}bar
+; CHECK: .set bar.num_vgpr, max(42, baz.num_vgpr)
+; CHECK: .set bar.num_agpr, max(0, baz.num_agpr)
+; CHECK: .set bar.numbered_sgpr, max(34, baz.numbered_sgpr)
+; CHECK: .set bar.private_seg_size, 16+(max(baz.private_seg_size))
+; CHECK: .set bar.uses_vcc, or(1, baz.uses_vcc)
+; CHECK: .set bar.uses_flat_scratch, or(0, baz.uses_flat_scratch)
+; CHECK: .set bar.has_dyn_sized_stack, or(0, baz.has_dyn_sized_stack)
+; CHECK: .set bar.has_recursion, or(1, baz.has_recursion)
+; CHECK: .set bar.has_indirect_call, or(0, baz.has_indirect_call)
+
+; CHECK-LABEL: {{^}}foo
+; CHECK: .set foo.num_vgpr, 42
+; CHECK: .set foo.num_agpr, 0
+; CHECK: .set foo.numbered_sgpr, 34
+; CHECK: .set foo.private_seg_size, 16
+; CHECK: .set foo.uses_vcc, 1
+; CHECK: .set foo.uses_flat_scratch, 0
+; CHECK: .set foo.has_dyn_sized_stack, 0
+; CHECK: .set foo.has_recursion, 1
+; CHECK: .set foo.has_indirect_call, 0
+
+define void @foo() {
+entry:
+  call void @bar()
+  ret void
+}
+
+define void @bar() {
+entry:
+  call void @baz()
+  ret void
+}
+
+define void @baz() {
+entry:
+  call void @qux()
+  ret void
+}
+
+define void @qux() {
+entry:
+  call void @foo()
+  ret void
+}
+
+; CHECK-LABEL: {{^}}usefoo
+; CHECK: .set usefoo.num_vgpr, max(32, foo.num_vgpr)
+; CHECK: .set usefoo.num_agpr, max(0, foo.num_agpr)
+; CHECK: .set usefoo.numbered_sgpr, max(33, foo.numbered_sgpr)
+; CHECK: .set usefoo.private_seg_size, 0+(max(foo.private_seg_size))
+; CHECK: .set usefoo.uses_vcc, or(1, foo.uses_vcc)
+; CHECK: .set usefoo.uses_flat_scratch, or(1, foo.uses_flat_scratch)
+; CHECK: .set usefoo.has_dyn_sized_stack, or(0, foo.has_dyn_sized_stack)
+; CHECK: .set usefoo.has_recursion, or(1, foo.has_recursion)
+; CHECK: .set usefoo.has_indirect_call, or(0, foo.has_indirect_call)
+define amdgpu_kernel void @usefoo() {
+  call void @foo()
+  ret void
+}
+

>From 546397105d7dbd780130274cbc08e5742dd4b277 Mon Sep 17 00:00:00 2001
From: Janek van Oirschot <janek.vanoirschot at amd.com>
Date: Tue, 15 Oct 2024 16:50:59 +0100
Subject: [PATCH 2/5] Address feedback: merge WorkList iteration into Callee iteration

---
 .../Target/AMDGPU/AMDGPUMCResourceInfo.cpp    | 87 +++++++++----------
 1 file changed, 40 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
index ee1453d1d733ba..5709fecf4dec25 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -91,16 +91,12 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
   return OutContext.getOrCreateSymbol("amdgpu.max_num_sgpr");
 }
 
-// The (partially complete) expression should have no recursion in it. After
-// all, we're trying to avoid recursion using this codepath. Returns true if
-// Sym is found within Expr without recursing on Expr, false otherwise.
+// The expression should have no recursion in it. Test a (sub-)expression to see
+// if it needs to be further visited, or if a recursion has been found. Returns
+// true if Sym is found within Expr (i.e., has a recurrance of Sym found), false
+// otherwise.
 static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
-                             SmallVectorImpl<const MCExpr *> &Exprs,
-                             SmallPtrSetImpl<const MCExpr *> &Visited) {
-  // Skip duplicate visits
-  if (!Visited.insert(Expr).second)
-    return false;
-
+                             SmallPtrSetImpl<const MCExpr *> &Exprs) {
   switch (Expr->getKind()) {
   default:
     return false;
@@ -110,49 +106,29 @@ static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
     if (Sym == &SymRef)
       return true;
     if (SymRef.isVariable())
-      Exprs.push_back(SymRef.getVariableValue(/*isUsed=*/false));
+      Exprs.insert(SymRef.getVariableValue(/*isUsed=*/false));
     return false;
   }
   case MCExpr::ExprKind::Binary: {
     const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
-    Exprs.push_back(BExpr->getLHS());
-    Exprs.push_back(BExpr->getRHS());
+    Exprs.insert(BExpr->getLHS());
+    Exprs.insert(BExpr->getRHS());
     return false;
   }
   case MCExpr::ExprKind::Unary: {
     const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
-    Exprs.push_back(UExpr->getSubExpr());
+    Exprs.insert(UExpr->getSubExpr());
     return false;
   }
   case MCExpr::ExprKind::Target: {
     const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
     for (const MCExpr *E : AGVK->getArgs())
-      Exprs.push_back(E);
+      Exprs.insert(E);
     return false;
   }
   }
 }
 
-// Symbols whose values eventually are used through their defines (i.e.,
-// recursive) must be avoided. Do a walk over Expr to see if Sym will occur in
-// it. The Expr is an MCExpr given through a callee's equivalent MCSymbol so if
-// no recursion is found Sym can be safely assigned to a (sub-)expr which
-// contains the symbol Expr is associated with. Returns true if Sym exists
-// in Expr or its sub-expressions, false otherwise.
-static bool foundRecursiveSymbolDef(MCSymbol *Sym, const MCExpr *Expr) {
-  SmallVector<const MCExpr *, 8> WorkList;
-  SmallPtrSet<const MCExpr *, 8> Visited;
-  WorkList.push_back(Expr);
-
-  while (!WorkList.empty()) {
-    const MCExpr *CurExpr = WorkList.pop_back_val();
-    if (findSymbolInExpr(Sym, CurExpr, WorkList, Visited))
-      return true;
-  }
-
-  return false;
-}
-
 void MCResourceInfo::assignResourceInfoExpr(
     int64_t LocalValue, ResourceInfoKind RIK, AMDGPUMCExpr::VariantKind Kind,
     const MachineFunction &MF, const SmallVectorImpl<const Function *> &Callees,
@@ -163,23 +139,31 @@ void MCResourceInfo::assignResourceInfoExpr(
   MCSymbol *Sym = getSymbol(MF.getName(), RIK, OutContext);
   if (!Callees.empty()) {
     SmallVector<const MCExpr *, 8> ArgExprs;
-    // Avoid recursive symbol assignment.
     SmallPtrSet<const Function *, 8> Seen;
     ArgExprs.push_back(LocalConstExpr);
-    const Function &F = MF.getFunction();
-    Seen.insert(&F);
 
     for (const Function *Callee : Callees) {
       if (!Seen.insert(Callee).second)
         continue;
+
+      SmallPtrSet<const MCExpr *, 8> WorkSet;
       MCSymbol *CalleeValSym = getSymbol(Callee->getName(), RIK, OutContext);
-      bool CalleeIsVar = CalleeValSym->isVariable();
-      if (!CalleeIsVar ||
-          (CalleeIsVar &&
-           !foundRecursiveSymbolDef(
-               Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))) {
+      if (CalleeValSym->isVariable())
+        WorkSet.insert(CalleeValSym->getVariableValue(/*IsUsed=*/false));
+      else
         ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+
+      bool FoundRecursion = false;
+      while (!WorkSet.empty() && !FoundRecursion) {
+        auto It = WorkSet.begin();
+        const MCExpr *Expr = *It;
+        WorkSet.erase(Expr);
+
+        FoundRecursion = findSymbolInExpr(Sym, Expr, WorkSet);
       }
+
+      if (CalleeValSym->isVariable() && !FoundRecursion)
+        ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
     }
     if (ArgExprs.size() > 1)
       SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
@@ -235,15 +219,24 @@ void MCResourceInfo::gatherResourceInfo(
       if (!Seen.insert(Callee).second)
         continue;
       if (!Callee->isDeclaration()) {
+        SmallPtrSet<const MCExpr *, 8> WorkSet;
         MCSymbol *CalleeValSym =
             getSymbol(Callee->getName(), RIK_PrivateSegSize, OutContext);
-        bool CalleeIsVar = CalleeValSym->isVariable();
-        if (!CalleeIsVar ||
-            (CalleeIsVar &&
-             !foundRecursiveSymbolDef(
-                 Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))) {
+        if (CalleeValSym->isVariable())
+          WorkSet.insert(CalleeValSym->getVariableValue(/*IsUsed=*/false));
+        else
           ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+
+        bool FoundRecursion = false;
+        while (!WorkSet.empty() && !FoundRecursion) {
+          auto It = WorkSet.begin();
+          const MCExpr *Expr = *It;
+          WorkSet.erase(Expr);
+
+          FoundRecursion = findSymbolInExpr(Sym, Expr, WorkSet);
         }
+        if (CalleeValSym->isVariable() && !FoundRecursion)
+          ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
       }
     }
     const MCExpr *localConstExpr =

>From ba27cd99e9518f24c57327b0e8cb42a5c2da7140 Mon Sep 17 00:00:00 2001
From: Janek van Oirschot <janek.vanoirschot at amd.com>
Date: Wed, 23 Oct 2024 20:57:11 +0100
Subject: [PATCH 3/5] Use a recursive walk instead of iterating over a WorkSet

---
 .../Target/AMDGPU/AMDGPUMCResourceInfo.cpp    | 81 ++++++++++---------
 1 file changed, 41 insertions(+), 40 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
index 5709fecf4dec25..f047f9c6e5aa1b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -96,34 +96,48 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
 // true if Sym is found within Expr (i.e., has a recurrance of Sym found), false
 // otherwise.
 static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
-                             SmallPtrSetImpl<const MCExpr *> &Exprs) {
+                             SmallPtrSetImpl<const MCExpr *> &Visited) {
+
+  if (Expr->getKind() == MCExpr::ExprKind::SymbolRef) {
+    const MCSymbolRefExpr *SymRefExpr = cast<MCSymbolRefExpr>(Expr);
+    const MCSymbol &SymRef = SymRefExpr->getSymbol();
+    if (Sym == &SymRef)
+      return true;
+  }
+
+  if (!Visited.insert(Expr).second)
+    return false;
+
   switch (Expr->getKind()) {
   default:
     return false;
   case MCExpr::ExprKind::SymbolRef: {
     const MCSymbolRefExpr *SymRefExpr = cast<MCSymbolRefExpr>(Expr);
     const MCSymbol &SymRef = SymRefExpr->getSymbol();
-    if (Sym == &SymRef)
-      return true;
-    if (SymRef.isVariable())
-      Exprs.insert(SymRef.getVariableValue(/*isUsed=*/false));
+    if (SymRef.isVariable()) {
+      return findSymbolInExpr(Sym, SymRef.getVariableValue(/*isUsed=*/false),
+                              Visited);
+    }
     return false;
   }
   case MCExpr::ExprKind::Binary: {
     const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
-    Exprs.insert(BExpr->getLHS());
-    Exprs.insert(BExpr->getRHS());
+    if (findSymbolInExpr(Sym, BExpr->getLHS(), Visited) ||
+        findSymbolInExpr(Sym, BExpr->getRHS(), Visited)) {
+      return true;
+    }
     return false;
   }
   case MCExpr::ExprKind::Unary: {
     const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
-    Exprs.insert(UExpr->getSubExpr());
-    return false;
+    return findSymbolInExpr(Sym, UExpr->getSubExpr(), Visited);
   }
   case MCExpr::ExprKind::Target: {
     const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
-    for (const MCExpr *E : AGVK->getArgs())
-      Exprs.insert(E);
+    for (const MCExpr *E : AGVK->getArgs()) {
+      if (findSymbolInExpr(Sym, E, Visited))
+        return true;
+    }
     return false;
   }
   }
@@ -146,24 +160,17 @@ void MCResourceInfo::assignResourceInfoExpr(
       if (!Seen.insert(Callee).second)
         continue;
 
-      SmallPtrSet<const MCExpr *, 8> WorkSet;
+      SmallPtrSet<const MCExpr *, 8> Visited;
       MCSymbol *CalleeValSym = getSymbol(Callee->getName(), RIK, OutContext);
-      if (CalleeValSym->isVariable())
-        WorkSet.insert(CalleeValSym->getVariableValue(/*IsUsed=*/false));
-      else
-        ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
-
-      bool FoundRecursion = false;
-      while (!WorkSet.empty() && !FoundRecursion) {
-        auto It = WorkSet.begin();
-        const MCExpr *Expr = *It;
-        WorkSet.erase(Expr);
-
-        FoundRecursion = findSymbolInExpr(Sym, Expr, WorkSet);
-      }
+      bool CalleeIsVar = CalleeValSym->isVariable();
 
-      if (CalleeValSym->isVariable() && !FoundRecursion)
+      if (!CalleeIsVar ||
+          (CalleeIsVar &&
+           !findSymbolInExpr(Sym,
+                             CalleeValSym->getVariableValue(/*IsUsed=*/false),
+                             Visited))) {
         ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+      }
     }
     if (ArgExprs.size() > 1)
       SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
@@ -219,24 +226,18 @@ void MCResourceInfo::gatherResourceInfo(
       if (!Seen.insert(Callee).second)
         continue;
       if (!Callee->isDeclaration()) {
-        SmallPtrSet<const MCExpr *, 8> WorkSet;
+        SmallPtrSet<const MCExpr *, 8> Visited;
         MCSymbol *CalleeValSym =
             getSymbol(Callee->getName(), RIK_PrivateSegSize, OutContext);
-        if (CalleeValSym->isVariable())
-          WorkSet.insert(CalleeValSym->getVariableValue(/*IsUsed=*/false));
-        else
-          ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+        bool CalleeIsVar = CalleeValSym->isVariable();
 
-        bool FoundRecursion = false;
-        while (!WorkSet.empty() && !FoundRecursion) {
-          auto It = WorkSet.begin();
-          const MCExpr *Expr = *It;
-          WorkSet.erase(Expr);
-
-          FoundRecursion = findSymbolInExpr(Sym, Expr, WorkSet);
-        }
-        if (CalleeValSym->isVariable() && !FoundRecursion)
+        if (!CalleeIsVar ||
+            (CalleeIsVar &&
+             !findSymbolInExpr(Sym,
+                               CalleeValSym->getVariableValue(/*IsUsed=*/false),
+                               Visited))) {
           ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
+        }
       }
     }
     const MCExpr *localConstExpr =

>From f99b85938e0c08e0c5ffc12dce8e6f3f50fc7181 Mon Sep 17 00:00:00 2001
From: Janek van Oirschot <janek.vanoirschot at amd.com>
Date: Thu, 31 Oct 2024 13:36:57 +0000
Subject: [PATCH 4/5] Address review feedback: remove redundant CalleeIsVar check

---
 .../Target/AMDGPU/AMDGPUMCResourceInfo.cpp    | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
index f047f9c6e5aa1b..883a1645293b85 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -162,13 +162,10 @@ void MCResourceInfo::assignResourceInfoExpr(
 
       SmallPtrSet<const MCExpr *, 8> Visited;
       MCSymbol *CalleeValSym = getSymbol(Callee->getName(), RIK, OutContext);
-      bool CalleeIsVar = CalleeValSym->isVariable();
 
-      if (!CalleeIsVar ||
-          (CalleeIsVar &&
-           !findSymbolInExpr(Sym,
-                             CalleeValSym->getVariableValue(/*IsUsed=*/false),
-                             Visited))) {
+      if (!CalleeValSym->isVariable() ||
+          !findSymbolInExpr(
+              Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false), Visited)) {
         ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
       }
     }
@@ -229,13 +226,11 @@ void MCResourceInfo::gatherResourceInfo(
         SmallPtrSet<const MCExpr *, 8> Visited;
         MCSymbol *CalleeValSym =
             getSymbol(Callee->getName(), RIK_PrivateSegSize, OutContext);
-        bool CalleeIsVar = CalleeValSym->isVariable();
 
-        if (!CalleeIsVar ||
-            (CalleeIsVar &&
-             !findSymbolInExpr(Sym,
-                               CalleeValSym->getVariableValue(/*IsUsed=*/false),
-                               Visited))) {
+        if (!CalleeValSym->isVariable() ||
+            !findSymbolInExpr(Sym,
+                              CalleeValSym->getVariableValue(/*IsUsed=*/false),
+                              Visited)) {
           ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
         }
       }

>From b18ed55b50f917e2a1646eb1c43cb5b0b149134c Mon Sep 17 00:00:00 2001
From: Janek van Oirschot <janek.vanoirschot at amd.com>
Date: Fri, 8 Nov 2024 20:29:18 +0000
Subject: [PATCH 5/5] Move isSymbolUsedInExpression into MCExpr, use it for
 recursion detection, and let MCTargetExpr subclasses report symbol uses in
 their own subexpressions

---
 llvm/include/llvm/MC/MCExpr.h                 |  7 ++
 llvm/lib/MC/MCExpr.cpp                        | 29 +++++++++
 llvm/lib/MC/MCParser/AsmParser.cpp            | 29 +--------
 .../Target/AMDGPU/AMDGPUMCResourceInfo.cpp    | 64 ++-----------------
 .../AMDGPU/MCTargetDesc/AMDGPUMCExpr.cpp      |  8 +++
 .../Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.h |  1 +
 6 files changed, 50 insertions(+), 88 deletions(-)

diff --git a/llvm/include/llvm/MC/MCExpr.h b/llvm/include/llvm/MC/MCExpr.h
index 10bc6ebd6fe506..ece51ebecdd9c1 100644
--- a/llvm/include/llvm/MC/MCExpr.h
+++ b/llvm/include/llvm/MC/MCExpr.h
@@ -86,6 +86,10 @@ class MCExpr {
              bool InParens = false) const;
   void dump() const;
 
+  /// Returns true if \p Sym is used anywhere in this expression or its
+  /// subexpressions, looking through the values of non-weak variable symbols.
+  bool isSymbolUsedInExpression(const MCSymbol *Sym) const;
+
   /// @}
   /// \name Expression Evaluation
   /// @{
@@ -663,6 +667,9 @@ class MCTargetExpr : public MCExpr {
                                          const MCFixup *Fixup) const = 0;
   // allow Target Expressions to be checked for equality
   virtual bool isEqualTo(const MCExpr *x) const { return false; }
+  virtual bool isSymbolUsedInExpression(const MCSymbol *Sym) const {
+    return false;
+  }
   // This should be set when assigned expressions are not valid ".set"
   // expressions, e.g. registers, and must be inlined.
   virtual bool inlineAssignedExpr() const { return false; }
diff --git a/llvm/lib/MC/MCExpr.cpp b/llvm/lib/MC/MCExpr.cpp
index c9d5f6580fda4c..ede7655733f253 100644
--- a/llvm/lib/MC/MCExpr.cpp
+++ b/llvm/lib/MC/MCExpr.cpp
@@ -177,6 +177,35 @@ LLVM_DUMP_METHOD void MCExpr::dump() const {
 }
 #endif
 
+bool MCExpr::isSymbolUsedInExpression(const MCSymbol *Sym) const {
+  switch (getKind()) {
+  case MCExpr::Binary: {
+    const MCBinaryExpr *BE = static_cast<const MCBinaryExpr *>(this);
+    return BE->getLHS()->isSymbolUsedInExpression(Sym) ||
+           BE->getRHS()->isSymbolUsedInExpression(Sym);
+  }
+  case MCExpr::Target: {
+    const MCTargetExpr *TE = static_cast<const MCTargetExpr *>(this);
+    return TE->isSymbolUsedInExpression(Sym);
+  }
+  case MCExpr::Constant:
+    return false;
+  case MCExpr::SymbolRef: {
+    const MCSymbol &S = static_cast<const MCSymbolRefExpr *>(this)->getSymbol();
+    if (S.isVariable() && !S.isWeakExternal())
+      return S.getVariableValue()->isSymbolUsedInExpression(Sym);
+    return &S == Sym;
+  }
+  case MCExpr::Unary: {
+    const MCExpr *SubExpr =
+        static_cast<const MCUnaryExpr *>(this)->getSubExpr();
+    return SubExpr->isSymbolUsedInExpression(Sym);
+  }
+  }
+
+  llvm_unreachable("Unknown expr kind!");
+}
+
 /* *** */
 
 const MCBinaryExpr *MCBinaryExpr::create(Opcode Opc, const MCExpr *LHS,
diff --git a/llvm/lib/MC/MCParser/AsmParser.cpp b/llvm/lib/MC/MCParser/AsmParser.cpp
index 3f55d8a66bc2ce..9b5eb96f9884da 100644
--- a/llvm/lib/MC/MCParser/AsmParser.cpp
+++ b/llvm/lib/MC/MCParser/AsmParser.cpp
@@ -6394,33 +6394,6 @@ bool HLASMAsmParser::parseStatement(ParseStatementInfo &Info,
 namespace llvm {
 namespace MCParserUtils {
 
-/// Returns whether the given symbol is used anywhere in the given expression,
-/// or subexpressions.
-static bool isSymbolUsedInExpression(const MCSymbol *Sym, const MCExpr *Value) {
-  switch (Value->getKind()) {
-  case MCExpr::Binary: {
-    const MCBinaryExpr *BE = static_cast<const MCBinaryExpr *>(Value);
-    return isSymbolUsedInExpression(Sym, BE->getLHS()) ||
-           isSymbolUsedInExpression(Sym, BE->getRHS());
-  }
-  case MCExpr::Target:
-  case MCExpr::Constant:
-    return false;
-  case MCExpr::SymbolRef: {
-    const MCSymbol &S =
-        static_cast<const MCSymbolRefExpr *>(Value)->getSymbol();
-    if (S.isVariable() && !S.isWeakExternal())
-      return isSymbolUsedInExpression(Sym, S.getVariableValue());
-    return &S == Sym;
-  }
-  case MCExpr::Unary:
-    return isSymbolUsedInExpression(
-        Sym, static_cast<const MCUnaryExpr *>(Value)->getSubExpr());
-  }
-
-  llvm_unreachable("Unknown expr kind!");
-}
-
 bool parseAssignmentExpression(StringRef Name, bool allow_redef,
                                MCAsmParser &Parser, MCSymbol *&Sym,
                                const MCExpr *&Value) {
@@ -6445,7 +6418,7 @@ bool parseAssignmentExpression(StringRef Name, bool allow_redef,
     //
     // FIXME: Diagnostics. Note the location of the definition as a label.
     // FIXME: Diagnose assignment to protected identifier (e.g., register name).
-    if (isSymbolUsedInExpression(Sym, Value))
+    if (Value->isSymbolUsedInExpression(Sym))
       return Parser.Error(EqualLoc, "Recursive use of '" + Name + "'");
     else if (Sym->isUndefined(/*SetUsed*/ false) && !Sym->isUsed() &&
              !Sym->isVariable())
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
index 883a1645293b85..ad257fbe426f89 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -91,58 +91,6 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
   return OutContext.getOrCreateSymbol("amdgpu.max_num_sgpr");
 }
 
-// The expression should have no recursion in it. Test a (sub-)expression to see
-// if it needs to be further visited, or if a recursion has been found. Returns
-// true if Sym is found within Expr (i.e., has a recurrance of Sym found), false
-// otherwise.
-static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
-                             SmallPtrSetImpl<const MCExpr *> &Visited) {
-
-  if (Expr->getKind() == MCExpr::ExprKind::SymbolRef) {
-    const MCSymbolRefExpr *SymRefExpr = cast<MCSymbolRefExpr>(Expr);
-    const MCSymbol &SymRef = SymRefExpr->getSymbol();
-    if (Sym == &SymRef)
-      return true;
-  }
-
-  if (!Visited.insert(Expr).second)
-    return false;
-
-  switch (Expr->getKind()) {
-  default:
-    return false;
-  case MCExpr::ExprKind::SymbolRef: {
-    const MCSymbolRefExpr *SymRefExpr = cast<MCSymbolRefExpr>(Expr);
-    const MCSymbol &SymRef = SymRefExpr->getSymbol();
-    if (SymRef.isVariable()) {
-      return findSymbolInExpr(Sym, SymRef.getVariableValue(/*isUsed=*/false),
-                              Visited);
-    }
-    return false;
-  }
-  case MCExpr::ExprKind::Binary: {
-    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
-    if (findSymbolInExpr(Sym, BExpr->getLHS(), Visited) ||
-        findSymbolInExpr(Sym, BExpr->getRHS(), Visited)) {
-      return true;
-    }
-    return false;
-  }
-  case MCExpr::ExprKind::Unary: {
-    const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
-    return findSymbolInExpr(Sym, UExpr->getSubExpr(), Visited);
-  }
-  case MCExpr::ExprKind::Target: {
-    const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
-    for (const MCExpr *E : AGVK->getArgs()) {
-      if (findSymbolInExpr(Sym, E, Visited))
-        return true;
-    }
-    return false;
-  }
-  }
-}
-
 void MCResourceInfo::assignResourceInfoExpr(
     int64_t LocalValue, ResourceInfoKind RIK, AMDGPUMCExpr::VariantKind Kind,
     const MachineFunction &MF, const SmallVectorImpl<const Function *> &Callees,
@@ -160,12 +108,10 @@ void MCResourceInfo::assignResourceInfoExpr(
       if (!Seen.insert(Callee).second)
         continue;
 
-      SmallPtrSet<const MCExpr *, 8> Visited;
       MCSymbol *CalleeValSym = getSymbol(Callee->getName(), RIK, OutContext);
-
       if (!CalleeValSym->isVariable() ||
-          !findSymbolInExpr(
-              Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false), Visited)) {
+          !CalleeValSym->getVariableValue(/*isUsed=*/false)
+               ->isSymbolUsedInExpression(Sym)) {
         ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
       }
     }
@@ -223,14 +169,12 @@ void MCResourceInfo::gatherResourceInfo(
       if (!Seen.insert(Callee).second)
         continue;
       if (!Callee->isDeclaration()) {
-        SmallPtrSet<const MCExpr *, 8> Visited;
         MCSymbol *CalleeValSym =
             getSymbol(Callee->getName(), RIK_PrivateSegSize, OutContext);
 
         if (!CalleeValSym->isVariable() ||
-            !findSymbolInExpr(Sym,
-                              CalleeValSym->getVariableValue(/*IsUsed=*/false),
-                              Visited)) {
+            !CalleeValSym->getVariableValue(/*isUsed=*/false)
+                 ->isSymbolUsedInExpression(Sym)) {
           ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
         }
       }
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.cpp
index d1212ec76f9860..91ff4148bb6329 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.cpp
@@ -306,6 +306,14 @@ const AMDGPUMCExpr *AMDGPUMCExpr::createOccupancy(unsigned InitOcc,
                 Ctx);
 }
 
+bool AMDGPUMCExpr::isSymbolUsedInExpression(const MCSymbol *Sym) const {
+  for (const MCExpr *E : getArgs()) {
+    if (E->isSymbolUsedInExpression(Sym))
+      return true;
+  }
+  return false;
+}
+
 static KnownBits fromOptionalToKnownBits(std::optional<bool> CompareResult) {
   static constexpr unsigned BitWidth = 64;
   const APInt True(BitWidth, 1);
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.h
index a16843f404b8f6..75e676bb7d5081 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.h
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.h
@@ -97,6 +97,7 @@ class AMDGPUMCExpr : public MCTargetExpr {
   void printImpl(raw_ostream &OS, const MCAsmInfo *MAI) const override;
   bool evaluateAsRelocatableImpl(MCValue &Res, const MCAssembler *Asm,
                                  const MCFixup *Fixup) const override;
+  bool isSymbolUsedInExpression(const MCSymbol *Sym) const override;
   void visitUsedExpr(MCStreamer &Streamer) const override;
   MCFragment *findAssociatedFragment() const override;
   void fixELFSymbolsInTLSFixups(MCAssembler &) const override{};



More information about the llvm-commits mailing list