[llvm] [FuncSpec] Update function specialization to handle phi-chains (PR #71442)

Mats Petersson via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 16 08:09:19 PST 2023


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/71442

>From 966954ca7bf9cdfeb1307fe0ddbd044300bd3680 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Mon, 6 Nov 2023 19:14:53 +0000
Subject: [PATCH 1/8] [FuncSpec] Update function specialization to handle
 phi-chains

When using the LLVM flang compiler with alias analysis (AA) enabled,
SPEC2017:548.exchange2_r was running significantly slower than
without AA.

This was caused by the GVN pass replacing many of the loads in the
pre-AA code with phi-nodes that form a long chain of dependencies,
which the function specialization was unable to follow.

This adds a function to follow phi-nodes when they form a strongly
connected component, with some limits to avoid spending excessive
time analysing phi-nodes.

The minimum latency savings threshold also had to be lowered: fewer
load instructions means smaller savings.

Add some more debug prints to help debug the isProfitable decision.

No significant change in compile time or generated code-size.

Co-authored-by: Alexandros Lamprineas <alexandros.lamprineas at arm.com>
---
 .../Transforms/IPO/FunctionSpecialization.h   |   4 +
 .../Transforms/IPO/FunctionSpecialization.cpp | 133 ++++++++++++++----
 .../discover-strongly-connected-phis.ll       |  87 ++++++++++++
 3 files changed, 198 insertions(+), 26 deletions(-)
 create mode 100644 llvm/test/Transforms/FunctionSpecialization/discover-strongly-connected-phis.ll

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index 50f9aae73dc53e2..f35543cb8411b35 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -183,6 +183,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   DenseSet<BasicBlock *> DeadBlocks;
   // PHI nodes we have visited before.
   DenseSet<Instruction *> VisitedPHIs;
+  // PHI nodes forming a strongly connected component.
+  DenseSet<PHINode *> StronglyConnectedPHIs;
   // PHI nodes we have visited once without successfully constant folding them.
   // Once the InstCostVisitor has processed all the specialization arguments,
   // it should be possible to determine whether those PHIs can be folded
@@ -217,6 +219,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   Cost estimateSwitchInst(SwitchInst &I);
   Cost estimateBranchInst(BranchInst &I);
 
+  void discoverStronglyConnectedComponent(PHINode *PN, unsigned Depth);
+
   Constant *visitInstruction(Instruction &I) { return nullptr; }
   Constant *visitPHINode(PHINode &I);
   Constant *visitFreezeInst(FreezeInst &I);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index b75ca7761a60b62..23e665a1901b5e1 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -39,10 +39,15 @@ static cl::opt<unsigned> MaxClones(
     "The maximum number of clones allowed for a single function "
     "specialization"));
 
+static cl::opt<unsigned> MaxDiscoveryDepth(
+    "funcspec-max-discovery-depth", cl::init(10), cl::Hidden,
+    cl::desc("The maximum recursion depth allowed when searching for strongly "
+             "connected phis"));
+
 static cl::opt<unsigned> MaxIncomingPhiValues(
-    "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
-    "The maximum number of incoming values a PHI node can have to be "
-    "considered during the specialization bonus estimation"));
+    "funcspec-max-incoming-phi-values", cl::init(8), cl::Hidden,
+    cl::desc("The maximum number of incoming values a PHI node can have to be "
+             "considered during the specialization bonus estimation"));
 
 static cl::opt<unsigned> MaxBlockPredecessors(
     "funcspec-max-block-predecessors", cl::init(2), cl::Hidden, cl::desc(
@@ -64,9 +69,9 @@ static cl::opt<unsigned> MinCodeSizeSavings(
     "much percent of the original function size"));
 
 static cl::opt<unsigned> MinLatencySavings(
-    "funcspec-min-latency-savings", cl::init(70), cl::Hidden, cl::desc(
-    "Reject specializations whose latency savings are less than this"
-    "much percent of the original function size"));
+    "funcspec-min-latency-savings", cl::init(45), cl::Hidden,
+    cl::desc("Reject specializations whose latency savings are less than this"
+             "much percent of the original function size"));
 
 static cl::opt<unsigned> MinInliningBonus(
     "funcspec-min-inlining-bonus", cl::init(300), cl::Hidden, cl::desc(
@@ -262,30 +267,86 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
   return estimateBasicBlocks(WorkList);
 }
 
+void InstCostVisitor::discoverStronglyConnectedComponent(PHINode *PN,
+                                                         unsigned Depth) {
+  if (Depth > MaxDiscoveryDepth)
+    return;
+
+  if (PN->getNumIncomingValues() > MaxIncomingPhiValues)
+    return;
+
+  if (!StronglyConnectedPHIs.insert(PN).second)
+    return;
+
+  for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
+    Value *V = PN->getIncomingValue(I);
+    if (auto *Phi = dyn_cast<PHINode>(V)) {
+      if (Phi == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
+        continue;
+      discoverStronglyConnectedComponent(Phi, Depth + 1);
+    }
+  }
+}
+
 Constant *InstCostVisitor::visitPHINode(PHINode &I) {
   if (I.getNumIncomingValues() > MaxIncomingPhiValues)
     return nullptr;
 
   bool Inserted = VisitedPHIs.insert(&I).second;
   Constant *Const = nullptr;
+  SmallVector<PHINode *, 8> UnknownIncomingValues;
 
-  for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
-    Value *V = I.getIncomingValue(Idx);
-    if (auto *Inst = dyn_cast<Instruction>(V))
-      if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
-        continue;
-    Constant *C = findConstantFor(V, KnownConstants);
-    if (!C) {
-      if (Inserted)
-        PendingPHIs.push_back(&I);
-      return nullptr;
+  auto CanConstantFoldPhi = [&](PHINode *PN) -> bool {
+    UnknownIncomingValues.clear();
+
+    for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
+      Value *V = PN->getIncomingValue(I);
+
+      // Disregard self-references and dead incoming values.
+      if (auto *Inst = dyn_cast<Instruction>(V))
+        if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
+          continue;
+
+      if (Constant *C = findConstantFor(V, KnownConstants)) {
+        if (!Const)
+          Const = C;
+        // Not all incoming values are the same constant. Bail immediately.
+        else if (C != Const)
+          return false;
+      } else if (auto *Phi = dyn_cast<PHINode>(V)) {
+        // It's not a strongly connected phi. Collect it and bail at the end.
+        if (!StronglyConnectedPHIs.contains(Phi))
+          UnknownIncomingValues.push_back(Phi);
+      } else {
+        // We can't reason about anything else.
+        return false;
+      }
+    }
+    return UnknownIncomingValues.empty();
+  };
+
+  if (CanConstantFoldPhi(&I))
+    return Const;
+
+  if (Inserted) {
+    // First time we are seeing this phi. We'll retry later, after all
+    // the constant arguments have been propagated. Bail for now.
+    PendingPHIs.push_back(&I);
+    return nullptr;
+  }
+
+  for (PHINode *Phi : UnknownIncomingValues)
+    discoverStronglyConnectedComponent(Phi, 1);
+
+  bool CannotConstantFoldPhi = false;
+  for (PHINode *Phi : StronglyConnectedPHIs) {
+    if (!CanConstantFoldPhi(Phi)) {
+      CannotConstantFoldPhi = true;
+      break;
     }
-    if (!Const)
-      Const = C;
-    else if (C != Const)
-      return nullptr;
   }
-  return Const;
+  StronglyConnectedPHIs.clear();
+  return CannotConstantFoldPhi ? nullptr : Const;
 }
 
 Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
@@ -809,20 +870,40 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
       auto IsProfitable = [](Bonus &B, unsigned Score, unsigned FuncSize,
                              unsigned FuncGrowth) -> bool {
         // No check required.
-        if (ForceSpecialization)
+        if (ForceSpecialization) {
+          LLVM_DEBUG(dbgs() << "Force is on\n");
           return true;
+        }
         // Minimum inlining bonus.
-        if (Score > MinInliningBonus * FuncSize / 100)
+        if (Score > MinInliningBonus * FuncSize / 100) {
+          LLVM_DEBUG(dbgs()
+                     << "FnSpecialization: Min inliningbous: Score = " << Score
+                     << " > " << MinInliningBonus * FuncSize / 100 << "\n");
           return true;
+        }
         // Minimum codesize savings.
-        if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100)
+        if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100) {
+          LLVM_DEBUG(dbgs()
+                     << "FnSpecialization: Min CodeSize Saving: CodeSize = "
+                     << B.CodeSize << " > "
+                     << MinCodeSizeSavings * FuncSize / 100 << "\n");
           return false;
+        }
         // Minimum latency savings.
-        if (B.Latency < MinLatencySavings * FuncSize / 100)
+        if (B.Latency < MinLatencySavings * FuncSize / 100) {
+          LLVM_DEBUG(dbgs()
+                     << "FnSpecialization: Min Latency Saving: Latency = "
+                     << B.Latency << " > " << MinLatencySavings * FuncSize / 100
+                     << "\n");
           return false;
+        }
         // Maximum codesize growth.
-        if (FuncGrowth / FuncSize > MaxCodeSizeGrowth)
+        if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) {
+          LLVM_DEBUG(dbgs() << "FnSpecialization: Max Func Growth: CodeSize = "
+                            << FuncGrowth / FuncSize << " > "
+                            << MaxCodeSizeGrowth << "\n");
           return false;
+        }
         return true;
       };
 
diff --git a/llvm/test/Transforms/FunctionSpecialization/discover-strongly-connected-phis.ll b/llvm/test/Transforms/FunctionSpecialization/discover-strongly-connected-phis.ll
new file mode 100644
index 000000000000000..3463ddb6f066de8
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/discover-strongly-connected-phis.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+;
+; RUN: opt -passes="ipsccp<func-spec>" -funcspec-min-function-size=20 -funcspec-for-literal-constant -S < %s | FileCheck %s --check-prefix=FUNCSPEC
+; RUN: opt -passes="ipsccp<func-spec>" -funcspec-min-function-size=20 -funcspec-for-literal-constant -funcspec-max-discovery-depth=5 -S < %s | FileCheck %s --check-prefix=NOFUNCSPEC
+
+define i64 @bar(i1 %c1, i1 %c2, i1 %c3, i1 %c4, i1 %c5, i1 %c6, i1 %c7, i1 %c8, i1 %c9, i1 %c10) {
+; FUNCSPEC-LABEL: define i64 @bar(
+; FUNCSPEC-SAME: i1 [[C1:%.*]], i1 [[C2:%.*]], i1 [[C3:%.*]], i1 [[C4:%.*]], i1 [[C5:%.*]], i1 [[C6:%.*]], i1 [[C7:%.*]], i1 [[C8:%.*]], i1 [[C9:%.*]], i1 [[C10:%.*]]) {
+; FUNCSPEC-NEXT:  entry:
+; FUNCSPEC-NEXT:    [[F1:%.*]] = call i64 @foo.specialized.1(i64 3, i1 [[C1]], i1 [[C2]], i1 [[C3]], i1 [[C4]], i1 [[C5]], i1 [[C6]], i1 [[C7]], i1 [[C8]], i1 [[C9]], i1 [[C10]]), !range [[RNG0:![0-9]+]]
+; FUNCSPEC-NEXT:    [[F2:%.*]] = call i64 @foo.specialized.2(i64 4, i1 [[C1]], i1 [[C2]], i1 [[C3]], i1 [[C4]], i1 [[C5]], i1 [[C6]], i1 [[C7]], i1 [[C8]], i1 [[C9]], i1 [[C10]]), !range [[RNG1:![0-9]+]]
+; FUNCSPEC-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[F1]], [[F2]]
+; FUNCSPEC-NEXT:    ret i64 [[ADD]]
+;
+; NOFUNCSPEC-LABEL: define i64 @bar(
+; NOFUNCSPEC-SAME: i1 [[C1:%.*]], i1 [[C2:%.*]], i1 [[C3:%.*]], i1 [[C4:%.*]], i1 [[C5:%.*]], i1 [[C6:%.*]], i1 [[C7:%.*]], i1 [[C8:%.*]], i1 [[C9:%.*]], i1 [[C10:%.*]]) {
+; NOFUNCSPEC-NEXT:  entry:
+; NOFUNCSPEC-NEXT:    [[F1:%.*]] = call i64 @foo(i64 3, i1 [[C1]], i1 [[C2]], i1 [[C3]], i1 [[C4]], i1 [[C5]], i1 [[C6]], i1 [[C7]], i1 [[C8]], i1 [[C9]], i1 [[C10]]), !range [[RNG0:![0-9]+]]
+; NOFUNCSPEC-NEXT:    [[F2:%.*]] = call i64 @foo(i64 4, i1 [[C1]], i1 [[C2]], i1 [[C3]], i1 [[C4]], i1 [[C5]], i1 [[C6]], i1 [[C7]], i1 [[C8]], i1 [[C9]], i1 [[C10]]), !range [[RNG0]]
+; NOFUNCSPEC-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[F1]], [[F2]]
+; NOFUNCSPEC-NEXT:    ret i64 [[ADD]]
+;
+entry:
+  %f1 = call i64 @foo(i64 3, i1 %c1, i1 %c2, i1 %c3, i1 %c4, i1 %c5, i1 %c6, i1 %c7, i1 %c8, i1 %c9, i1 %c10)
+  %f2 = call i64 @foo(i64 4, i1 %c1, i1 %c2, i1 %c3, i1 %c4, i1 %c5, i1 %c6, i1 %c7, i1 %c8, i1 %c9, i1 %c10)
+  %add = add i64 %f1, %f2
+  ret i64 %add
+}
+
+define internal i64 @foo(i64 %n, i1 %c1, i1 %c2, i1 %c3, i1 %c4, i1 %c5, i1 %c6, i1 %c7, i1 %c8, i1 %c9, i1 %c10) {
+entry:
+  br i1 %c1, label %l1, label %l9
+
+l1:
+  %phi1 = phi i64 [ %n, %entry ], [ %phi2, %l2 ]
+  %add = add i64 %phi1, 1
+  %div = sdiv i64 %add, 2
+  br i1 %c2, label %l1_5, label %exit
+
+l1_5:
+  br i1 %c3, label %l1_75, label %l6
+
+l1_75:
+  br i1 %c4, label %l2, label %l3
+
+l2:
+  %phi2 = phi i64 [ %phi1, %l1_75 ], [ %phi3, %l3 ]
+  br label %l1
+
+l3:
+  %phi3 = phi i64 [ %phi1, %l1_75 ], [ %phi4, %l4 ]
+  br label %l2
+
+l4:
+  %phi4 = phi i64 [ %phi5, %l5 ], [ %phi6, %l6 ]
+  br i1 %c5, label %l3, label %l6
+
+l5:
+  %phi5 = phi i64 [ %phi6, %l6_5 ], [ %phi7, %l7 ]
+  br label %l4
+
+l6:
+  %phi6 = phi i64 [ %phi4, %l4 ], [ %phi1, %l1_5 ]
+  br i1 %c6, label %l4, label %l6_5
+
+l6_5:
+  br i1 %c7, label %l5, label %l8
+
+l7:
+  %phi7 = phi i64 [ %phi9, %l9 ], [ %phi8, %l8 ]
+  br i1 %c8, label %l5, label %l8
+
+l8:
+  %phi8 = phi i64 [ %phi6, %l6_5 ], [ %phi7, %l7 ]
+  br i1 %c9, label %l7, label %l9
+
+l9:
+  %phi9 = phi i64 [ %n, %entry ], [ %phi8, %l8 ]
+  %sub = sub i64 %phi9, 1
+  %mul = mul i64 %sub, 2
+  br i1 %c10, label %l7, label %exit
+
+exit:
+  %res = phi i64 [ %div, %l1 ], [ %mul, %l9]
+  ret i64 %res
+}
+

>From 3fb7efdfeeab591bc943f23fe6899538c6e44411 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 7 Nov 2023 21:12:33 +0000
Subject: [PATCH 2/8] Update based on review comments

NOTE: We need to rewrite the overall commit message, as it is
no longer accurate.
---
 .../Transforms/IPO/FunctionSpecialization.h   |   5 +-
 .../Transforms/IPO/FunctionSpecialization.cpp | 167 +++++++++++++-----
 2 files changed, 126 insertions(+), 46 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index f35543cb8411b35..86cfcc0a5a77be0 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -183,8 +183,6 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   DenseSet<BasicBlock *> DeadBlocks;
   // PHI nodes we have visited before.
   DenseSet<Instruction *> VisitedPHIs;
-  // PHI nodes forming a strongly connected component.
-  DenseSet<PHINode *> StronglyConnectedPHIs;
   // PHI nodes we have visited once without successfully constant folding them.
   // Once the InstCostVisitor has processed all the specialization arguments,
   // it should be possible to determine whether those PHIs can be folded
@@ -219,7 +217,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   Cost estimateSwitchInst(SwitchInst &I);
   Cost estimateBranchInst(BranchInst &I);
 
-  void discoverStronglyConnectedComponent(PHINode *PN, unsigned Depth);
+  bool discoverTransitivelyIncomngValues(DenseSet<PHINode *> &PhiNodes,
+                                         PHINode *PN, unsigned Depth);
 
   Constant *visitInstruction(Instruction &I) { return nullptr; }
   Constant *visitPHINode(PHINode &I);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 23e665a1901b5e1..9af4676bc2d51af 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -267,38 +267,77 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
   return estimateBasicBlocks(WorkList);
 }
 
-void InstCostVisitor::discoverStronglyConnectedComponent(PHINode *PN,
-                                                         unsigned Depth) {
-  if (Depth > MaxDiscoveryDepth)
-    return;
+// This function is finding candidates for a PHINode is part of a chain or graph
+// of PHINodes that all link to each other. That means, if the original input to
+// the chain is a constant all the other values are also that constant.
+//
+// The caller of this function will later check that no other nodes are involved
+// that are non-constant, and discard it from the possible conversions.
+//
+// For example:
+//
+// %a = load %0
+// %c = phi [%a, %d]
+// %d = phi [%e, %c]
+// %e = phi [%c, %f]
+// %f = phi [%j, %h]
+// %j = phi [%h, %j]
+// %h = phi [%g, %c]
+//
+// This is only showing the PHINodes, not the branches that choose the
+// different paths.
+//
+// A depth limit is used to avoid extreme recurusion.
+// A max number of incoming phi values ensures that expensive searches
+// are avoided.
+//
+// Returns false if the discovery was aborted due to the above conditions.
+bool InstCostVisitor::discoverTransitivelyIncomngValues(
+    DenseSet<PHINode *> &PHINodes, PHINode *PN, unsigned Depth) {
+  if (Depth > MaxDiscoveryDepth) {
+    LLVM_DEBUG(dbgs() << "FnSpecialization: Discover PHI nodes too deep ("
+                      << Depth << ">" << MaxDiscoveryDepth << ")\n");
+    return false;
+  }
 
-  if (PN->getNumIncomingValues() > MaxIncomingPhiValues)
-    return;
+  if (PN->getNumIncomingValues() > MaxIncomingPhiValues) {
+    LLVM_DEBUG(
+        dbgs() << "FnSpecialization: Discover PHI nodes has too many values  ("
+               << PN->getNumIncomingValues() << ">" << MaxIncomingPhiValues
+               << ")\n");
+    return false;
+  }
 
-  if (!StronglyConnectedPHIs.insert(PN).second)
-    return;
+  // Already seen this, no more processing needed.
+  if (!PHINodes.insert(PN).second)
+    return true;
 
   for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
     Value *V = PN->getIncomingValue(I);
     if (auto *Phi = dyn_cast<PHINode>(V)) {
       if (Phi == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
         continue;
-      discoverStronglyConnectedComponent(Phi, Depth + 1);
+      if (!discoverTransitivelyIncomngValues(PHINodes, Phi, Depth + 1))
+        return false;
     }
   }
+  return true;
 }
 
 Constant *InstCostVisitor::visitPHINode(PHINode &I) {
   if (I.getNumIncomingValues() > MaxIncomingPhiValues)
     return nullptr;
 
+  // PHI nodes
+  DenseSet<PHINode *> TransitivePHIs;
+
   bool Inserted = VisitedPHIs.insert(&I).second;
-  Constant *Const = nullptr;
   SmallVector<PHINode *, 8> UnknownIncomingValues;
 
-  auto CanConstantFoldPhi = [&](PHINode *PN) -> bool {
-    UnknownIncomingValues.clear();
+  auto canConstantFoldPhiTrivially = [&](PHINode *PN) -> Constant * {
+    Constant *Const = nullptr;
 
+    UnknownIncomingValues.clear();
     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
       Value *V = PN->getIncomingValue(I);
 
@@ -311,21 +350,22 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
         if (!Const)
           Const = C;
         // Not all incoming values are the same constant. Bail immediately.
-        else if (C != Const)
-          return false;
-      } else if (auto *Phi = dyn_cast<PHINode>(V)) {
-        // It's not a strongly connected phi. Collect it and bail at the end.
-        if (!StronglyConnectedPHIs.contains(Phi))
-          UnknownIncomingValues.push_back(Phi);
-      } else {
-        // We can't reason about anything else.
-        return false;
+        if (C != Const)
+          return nullptr;
+        continue;
       }
+      if (auto *Phi = dyn_cast<PHINode>(V)) {
+        UnknownIncomingValues.push_back(Phi);
+        continue;
+      }
+
+      // We can't reason about anything else.
+      return nullptr;
     }
-    return UnknownIncomingValues.empty();
+    return UnknownIncomingValues.empty() ? Const : nullptr;
   };
 
-  if (CanConstantFoldPhi(&I))
+  if (Constant *Const = canConstantFoldPhiTrivially(&I))
     return Const;
 
   if (Inserted) {
@@ -335,18 +375,59 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
     return nullptr;
   }
 
+  // Try to see if we can collect a nest of transitive phis. Bail if
+  // it's too complex.
   for (PHINode *Phi : UnknownIncomingValues)
-    discoverStronglyConnectedComponent(Phi, 1);
+    if (!discoverTransitivelyIncomngValues(TransitivePHIs, Phi, 1))
+      return nullptr;
+
+  // A nested set of PHINodes can be constantfolded if:
+  // - It has a constant input.
+  // - It is always the SAME constant.
+  auto canConstantFoldNestedPhi = [&](PHINode *PN) -> Constant * {
+    Constant *Const = nullptr;
 
-  bool CannotConstantFoldPhi = false;
-  for (PHINode *Phi : StronglyConnectedPHIs) {
-    if (!CanConstantFoldPhi(Phi)) {
-      CannotConstantFoldPhi = true;
-      break;
+    for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
+      Value *V = PN->getIncomingValue(I);
+
+      // Disregard self-references and dead incoming values.
+      if (auto *Inst = dyn_cast<Instruction>(V))
+        if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
+          continue;
+
+      if (Constant *C = findConstantFor(V, KnownConstants)) {
+        if (!Const)
+          Const = C;
+        // Not all incoming values are the same constant. Bail immediately.
+        if (C != Const)
+          return nullptr;
+        continue;
+      }
+      if (auto *Phi = dyn_cast<PHINode>(V)) {
+        // It's not a Transitive phi. Bail out.
+        if (!TransitivePHIs.contains(Phi))
+          return nullptr;
+        continue;
+      }
+
+      // We can't reason about anything else.
+      return nullptr;
+    }
+    return Const;
+  };
+
+  // All TransitivePHIs have to be the SAME constant.
+  Constant *Retval = nullptr;
+  for (PHINode *Phi : TransitivePHIs) {
+    if (Constant *Const = canConstantFoldNestedPhi(Phi)) {
+      if (!Retval)
+        Retval = Const;
+      else if (Retval != Const)
+        return nullptr;
     }
   }
-  StronglyConnectedPHIs.clear();
-  return CannotConstantFoldPhi ? nullptr : Const;
+
+  return Retval;
 }
 
 Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
@@ -871,37 +952,37 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
                              unsigned FuncGrowth) -> bool {
         // No check required.
         if (ForceSpecialization) {
-          LLVM_DEBUG(dbgs() << "Force is on\n");
+          LLVM_DEBUG(dbgs() << "FnSpecialization: Force is on\n");
           return true;
         }
         // Minimum inlining bonus.
         if (Score > MinInliningBonus * FuncSize / 100) {
           LLVM_DEBUG(dbgs()
-                     << "FnSpecialization: Min inliningbous: Score = " << Score
-                     << " > " << MinInliningBonus * FuncSize / 100 << "\n");
+                     << "FnSpecialization: Sufficient inlining bonus (" << Score
+                     << " > " << MinInliningBonus * FuncSize / 100 << ")\n");
           return true;
         }
         // Minimum codesize savings.
         if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100) {
           LLVM_DEBUG(dbgs()
-                     << "FnSpecialization: Min CodeSize Saving: CodeSize = "
+                     << "FnSpecialization: Insufficinet CodeSize Saving ("
                      << B.CodeSize << " > "
-                     << MinCodeSizeSavings * FuncSize / 100 << "\n");
+                     << MinCodeSizeSavings * FuncSize / 100 << ")\n");
           return false;
         }
         // Minimum latency savings.
         if (B.Latency < MinLatencySavings * FuncSize / 100) {
-          LLVM_DEBUG(dbgs()
-                     << "FnSpecialization: Min Latency Saving: Latency = "
-                     << B.Latency << " > " << MinLatencySavings * FuncSize / 100
-                     << "\n");
+          LLVM_DEBUG(dbgs() << "FnSpecialization: Insufficinet Latency Saving ("
+                            << B.Latency << " > "
+                            << MinLatencySavings * FuncSize / 100 << ")\n");
           return false;
         }
         // Maximum codesize growth.
         if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) {
-          LLVM_DEBUG(dbgs() << "FnSpecialization: Max Func Growth: CodeSize = "
-                            << FuncGrowth / FuncSize << " > "
-                            << MaxCodeSizeGrowth << "\n");
+          LLVM_DEBUG(dbgs()
+                     << "FnSpecialization: Function Growth exceeds threshold ("
+                     << FuncGrowth / FuncSize << " > " << MaxCodeSizeGrowth
+                     << ")\n");
           return false;
         }
         return true;

>From c951ec9d8a3b8b9307f22512ba42a95036f6aa3b Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Fri, 10 Nov 2023 14:27:42 +0000
Subject: [PATCH 3/8] Fix some more debug output

---
 llvm/lib/Transforms/IPO/FunctionSpecialization.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 9af4676bc2d51af..e28d019fec89377 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -965,16 +965,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
         // Minimum codesize savings.
         if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100) {
           LLVM_DEBUG(dbgs()
-                     << "FnSpecialization: Insufficinet CodeSize Saving ("
-                     << B.CodeSize << " > "
+                     << "FnSpecialization: Insufficient CodeSize Savings ("
+                     << B.CodeSize << " < "
                      << MinCodeSizeSavings * FuncSize / 100 << ")\n");
           return false;
         }
         // Minimum latency savings.
         if (B.Latency < MinLatencySavings * FuncSize / 100) {
-          LLVM_DEBUG(dbgs() << "FnSpecialization: Insufficinet Latency Saving ("
-                            << B.Latency << " > "
-                            << MinLatencySavings * FuncSize / 100 << ")\n");
+          LLVM_DEBUG(dbgs()
+                     << "FnSpecialization: Insufficient Latency Savings ("
+                     << B.Latency << " < " << MinLatencySavings * FuncSize / 100
+                     << ")\n");
           return false;
         }
         // Maximum codesize growth.

>From cc8ac518760b11a658289d78e278f346c8b487c1 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Mon, 13 Nov 2023 17:47:29 +0000
Subject: [PATCH 4/8] Don't bail out completely when unable to discover all
 phi-nodes

---
 .../Transforms/IPO/FunctionSpecialization.h    |  2 +-
 .../Transforms/IPO/FunctionSpecialization.cpp  | 18 +++++++-----------
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index 86cfcc0a5a77be0..e9bc92599d9998b 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -217,7 +217,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   Cost estimateSwitchInst(SwitchInst &I);
   Cost estimateBranchInst(BranchInst &I);
 
-  bool discoverTransitivelyIncomngValues(DenseSet<PHINode *> &PhiNodes,
+  void discoverTransitivelyIncomngValues(DenseSet<PHINode *> &PhiNodes,
                                          PHINode *PN, unsigned Depth);
 
   Constant *visitInstruction(Instruction &I) { return nullptr; }
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index e28d019fec89377..1cbe567c9230afc 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -290,14 +290,12 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
 // A depth limit is used to avoid extreme recurusion.
 // A max number of incoming phi values ensures that expensive searches
 // are avoided.
-//
-// Returns false if the discovery was aborted due to the above conditions.
-bool InstCostVisitor::discoverTransitivelyIncomngValues(
+void InstCostVisitor::discoverTransitivelyIncomngValues(
     DenseSet<PHINode *> &PHINodes, PHINode *PN, unsigned Depth) {
   if (Depth > MaxDiscoveryDepth) {
     LLVM_DEBUG(dbgs() << "FnSpecialization: Discover PHI nodes too deep ("
                       << Depth << ">" << MaxDiscoveryDepth << ")\n");
-    return false;
+    return;
   }
 
   if (PN->getNumIncomingValues() > MaxIncomingPhiValues) {
@@ -305,23 +303,21 @@ bool InstCostVisitor::discoverTransitivelyIncomngValues(
         dbgs() << "FnSpecialization: Discover PHI nodes has too many values  ("
                << PN->getNumIncomingValues() << ">" << MaxIncomingPhiValues
                << ")\n");
-    return false;
+    return;
   }
 
   // Already seen this, no more processing needed.
   if (!PHINodes.insert(PN).second)
-    return true;
+    return;
 
   for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
     Value *V = PN->getIncomingValue(I);
     if (auto *Phi = dyn_cast<PHINode>(V)) {
       if (Phi == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
         continue;
-      if (!discoverTransitivelyIncomngValues(PHINodes, Phi, Depth + 1))
-        return false;
+      discoverTransitivelyIncomngValues(PHINodes, Phi, Depth + 1);
     }
   }
-  return true;
 }
 
 Constant *InstCostVisitor::visitPHINode(PHINode &I) {
@@ -378,8 +374,8 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
   // Try to see if we can collect a nest of transitive phis. Bail if
   // it's too complex.
   for (PHINode *Phi : UnknownIncomingValues)
-    if (!discoverTransitivelyIncomngValues(TransitivePHIs, Phi, 1))
-      return nullptr;
+    discoverTransitivelyIncomngValues(TransitivePHIs, Phi, 1);
+
 
   // A nested set of PHINodes can be constantfolded if:
   // - It has a constant input.

>From a942a55554d1fac6020ba3337458d7c23e251fc7 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 14 Nov 2023 17:14:04 +0000
Subject: [PATCH 5/8] Adjust latency limit

---
 llvm/lib/Transforms/IPO/FunctionSpecialization.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 1cbe567c9230afc..dc28837440dc604 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -69,7 +69,7 @@ static cl::opt<unsigned> MinCodeSizeSavings(
     "much percent of the original function size"));
 
 static cl::opt<unsigned> MinLatencySavings(
-    "funcspec-min-latency-savings", cl::init(45), cl::Hidden,
+    "funcspec-min-latency-savings", cl::init(40), cl::Hidden,
     cl::desc("Reject specializations whose latency savings are less than this"
              "much percent of the original function size"));
 

>From ad3bf332456a1c6231c6cdd2734ba141cf08b009 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 15 Nov 2023 16:14:04 +0000
Subject: [PATCH 6/8] Improve constant-folding of nested phinodes.

---
 .../Transforms/IPO/FunctionSpecialization.cpp | 43 +++++++++++++------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index dc28837440dc604..9293aaf17c27ce2 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -371,56 +371,71 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
     return nullptr;
   }
 
-  // Try to see if we can collect a nest of transitive phis. Bail if
-  // it's too complex.
+  // Try to see if we can collect a nest of transitive phis.
   for (PHINode *Phi : UnknownIncomingValues)
     discoverTransitivelyIncomngValues(TransitivePHIs, Phi, 1);
 
-
   // A nested set of PHINodes can be constantfolded if:
   // - It has a constant input.
   // - It is always the SAME constant.
-  auto canConstantFoldNestedPhi = [&](PHINode *PN) -> Constant * {
-    Constant *Const = nullptr;
+  // - All the nodes are part of the nest, or a constant.
+  // Later we will check that the constant is always the same one.
+  Constant *Const = nullptr;
+  enum FoldStatus {
+    Failed,    // Stop, this can't be folded.
+    KeepGoing, // Maybe can be folded, didn't find a constant.
+    FoundConst // Maybe can be folded, we found constant.
+  };
+  auto canConstantFoldNestedPhi = [&](PHINode *PN) -> FoldStatus {
+    FoldStatus Status = KeepGoing;
 
     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
       Value *V = PN->getIncomingValue(I);
-
       // Disregard self-references and dead incoming values.
       if (auto *Inst = dyn_cast<Instruction>(V))
         if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
           continue;
 
       if (Constant *C = findConstantFor(V, KnownConstants)) {
-        if (!Const)
+        if (!Const) {
           Const = C;
+          Status = FoundConst;
+        }
         // Not all incoming values are the same constant. Bail immediately.
         if (C != Const)
-          return nullptr;
+          return Failed;
         continue;
       }
       if (auto *Phi = dyn_cast<PHINode>(V)) {
         // It's not a Transitive phi. Bail out.
         if (!TransitivePHIs.contains(Phi))
-          return nullptr;
+          return Failed;
         continue;
       }
 
       // We can't reason about anything else.
-      return nullptr;
+      return Failed;
     }
-    return Const;
+    return Status;
   };
 
   // All TransitivePHIs have to be the SAME constant.
   Constant *Retval = nullptr;
   for (PHINode *Phi : TransitivePHIs) {
-    if (Constant *Const = canConstantFoldNestedPhi(Phi)) {
-      if (!Retval)
+    FoldStatus Status = canConstantFoldNestedPhi(Phi);
+    if (Status == FoundConst) {
+      if (!Retval) {
         Retval = Const;
-      else if (Retval != Const)
+        continue;
+      }
+      // Found more than one constant, can't fold.
+      if (Retval != Const)
         return nullptr;
     }
+    // Found something "wrong", can't fold.
+    else if (Status == Failed)
+      return nullptr;
+    assert(Status == KeepGoing && "Status should be KeepGoing here");
   }
 
   return Retval;

>From c22b3fedfc362acb8464bec9774f9a45abe23fe5 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 16 Nov 2023 12:28:17 +0000
Subject: [PATCH 7/8] Update canConstantFold lambdas

---
 .../Transforms/IPO/FunctionSpecialization.h   |  4 +-
 .../Transforms/IPO/FunctionSpecialization.cpp | 64 ++++++++-----------
 ...nected-phis.ll => discover-nested-phis.ll} |  0
 3 files changed, 28 insertions(+), 40 deletions(-)
 rename llvm/test/Transforms/FunctionSpecialization/{discover-strongly-connected-phis.ll => discover-nested-phis.ll} (100%)

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index e9bc92599d9998b..0629a4789e59c6a 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -217,8 +217,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   Cost estimateSwitchInst(SwitchInst &I);
   Cost estimateBranchInst(BranchInst &I);
 
-  void discoverTransitivelyIncomngValues(DenseSet<PHINode *> &PhiNodes,
-                                         PHINode *PN, unsigned Depth);
+  void discoverTransitivelyIncomingValues(DenseSet<PHINode *> &PhiNodes,
+                                          PHINode *PN, unsigned Depth);
 
   Constant *visitInstruction(Instruction &I) { return nullptr; }
   Constant *visitPHINode(PHINode &I);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 9293aaf17c27ce2..0b0648e0836a2b1 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -290,7 +290,7 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
 // A depth limit is used to avoid extreme recurusion.
 // A max number of incoming phi values ensures that expensive searches
 // are avoided.
-void InstCostVisitor::discoverTransitivelyIncomngValues(
+void InstCostVisitor::discoverTransitivelyIncomingValues(
     DenseSet<PHINode *> &PHINodes, PHINode *PN, unsigned Depth) {
   if (Depth > MaxDiscoveryDepth) {
     LLVM_DEBUG(dbgs() << "FnSpecialization: Discover PHI nodes too deep ("
@@ -315,7 +315,7 @@ void InstCostVisitor::discoverTransitivelyIncomngValues(
     if (auto *Phi = dyn_cast<PHINode>(V)) {
       if (Phi == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
         continue;
-      discoverTransitivelyIncomngValues(PHINodes, Phi, Depth + 1);
+      discoverTransitivelyIncomingValues(PHINodes, Phi, Depth + 1);
     }
   }
 }
@@ -328,12 +328,10 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
   DenseSet<PHINode *> TransitivePHIs;
 
   bool Inserted = VisitedPHIs.insert(&I).second;
-  SmallVector<PHINode *, 8> UnknownIncomingValues;
 
   auto canConstantFoldPhiTrivially = [&](PHINode *PN) -> Constant * {
     Constant *Const = nullptr;
 
-    UnknownIncomingValues.clear();
     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
       Value *V = PN->getIncomingValue(I);
 
@@ -350,15 +348,11 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
           return nullptr;
         continue;
       }
-      if (auto *Phi = dyn_cast<PHINode>(V)) {
-        UnknownIncomingValues.push_back(Phi);
-        continue;
-      }
 
       // We can't reason about anything else.
       return nullptr;
     }
-    return UnknownIncomingValues.empty() ? Const : nullptr;
+    return Const;
   };
 
   if (Constant *Const = canConstantFoldPhiTrivially(&I))
@@ -371,24 +365,18 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
     return nullptr;
   }
 
-  // Try to see if we can collect a nest of transitive phis.
-  for (PHINode *Phi : UnknownIncomingValues)
-    discoverTransitivelyIncomngValues(TransitivePHIs, Phi, 1);
+  for (unsigned J = 0, E = I.getNumIncomingValues(); J != E; ++J) {
+    Value *V = I.getIncomingValue(J);
+    if (auto *Phi = dyn_cast<PHINode>(V))
+      discoverTransitivelyIncomingValues(TransitivePHIs, Phi, 1);
+  }
 
   // A nested set of PHINodes can be constantfolded if:
   // - It has a constant input.
   // - It is always the SAME constant.
   // - All the nodes are part of the nest, or a constant.
   // Later we will check that the constant is always the same one.
-  Constant *Const = nullptr;
-  enum FoldStatus {
-    Failed,    // Stop, this can't be folded.
-    KeepGoing, // Maybe can be folded, didn't find a constant.
-    FoundConst // Maybe can be folded, we found constant.
-  };
-  auto canConstantFoldNestedPhi = [&](PHINode *PN) -> FoldStatus {
-    FoldStatus Status = KeepGoing;
-
+  auto canConstantFoldNestedPhi = [&](PHINode *PN, Constant *&Const) -> bool {
     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
       Value *V = PN->getIncomingValue(I);
       // Disregard self-references and dead incoming values.
@@ -397,45 +385,45 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
           continue;
 
       if (Constant *C = findConstantFor(V, KnownConstants)) {
-        if (!Const) {
+        if (!Const)
           Const = C;
-          Status = FoundConst;
-        }
+
         // Not all incoming values are the same constant. Bail immediately.
         if (C != Const)
-          return Failed;
+          return false;
         continue;
       }
       if (auto *Phi = dyn_cast<PHINode>(V)) {
         // It's not a Transitive phi. Bail out.
         if (!TransitivePHIs.contains(Phi))
-          return Failed;
+          return false;
         continue;
       }
 
       // We can't reason about anything else.
-      return Failed;
+      return false;
     }
-    return Status;
+    return true;
   };
 
   // All TransitivePHIs have to be the SAME constant.
   Constant *Retval = nullptr;
   for (PHINode *Phi : TransitivePHIs) {
-    FoldStatus Status = canConstantFoldNestedPhi(Phi);
-    if (Status == FoundConst) {
-      if (!Retval) {
-        Retval = Const;
-        continue;
+    Constant *Const = nullptr;
+    if (canConstantFoldNestedPhi(Phi, Const)) {
+      if (Const) {
+        if (!Retval) {
+          Retval = Const;
+          continue;
+        }
+        // Found more than one constant, can't fold.
+        if (Retval != Const)
+          return nullptr;
       }
-      // Found more than one constant, can't fold.
-      if (Retval != Const)
-        return nullptr;
     }
     // Found something "wrong", can't fold.
-    else if (Status == Failed)
+    else
       return nullptr;
-    assert(Status == KeepGoing && "Status should be KeepGoing here");
   }
 
   return Retval;
diff --git a/llvm/test/Transforms/FunctionSpecialization/discover-strongly-connected-phis.ll b/llvm/test/Transforms/FunctionSpecialization/discover-nested-phis.ll
similarity index 100%
rename from llvm/test/Transforms/FunctionSpecialization/discover-strongly-connected-phis.ll
rename to llvm/test/Transforms/FunctionSpecialization/discover-nested-phis.ll

>From e5b5bba914236e303f5cd2f16480ec98573f92d8 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 16 Nov 2023 16:01:43 +0000
Subject: [PATCH 8/8] Revert small lambda change that broke things

---
 .../Transforms/IPO/FunctionSpecialization.cpp    | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 0b0648e0836a2b1..2f8632e760a9b08 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -328,10 +328,12 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
   DenseSet<PHINode *> TransitivePHIs;
 
   bool Inserted = VisitedPHIs.insert(&I).second;
+  SmallVector<PHINode *, 8> UnknownIncomingValues;
 
   auto canConstantFoldPhiTrivially = [&](PHINode *PN) -> Constant * {
     Constant *Const = nullptr;
 
+    UnknownIncomingValues.clear();
     for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
       Value *V = PN->getIncomingValue(I);
 
@@ -348,11 +350,15 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
           return nullptr;
         continue;
       }
+      if (auto *Phi = dyn_cast<PHINode>(V)) {
+        UnknownIncomingValues.push_back(Phi);
+        continue;
+      }
 
       // We can't reason about anything else.
       return nullptr;
     }
-    return Const;
+    return UnknownIncomingValues.empty() ? Const : nullptr;
   };
 
   if (Constant *Const = canConstantFoldPhiTrivially(&I))
@@ -365,11 +371,9 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
     return nullptr;
   }
 
-  for (unsigned J = 0, E = I.getNumIncomingValues(); J != E; ++J) {
-    Value *V = I.getIncomingValue(J);
-    if (auto *Phi = dyn_cast<PHINode>(V))
-      discoverTransitivelyIncomingValues(TransitivePHIs, Phi, 1);
-  }
+  // Try to see if we can collect a nest of transitive phis.
+  for (PHINode *Phi : UnknownIncomingValues)
+    discoverTransitivelyIncomingValues(TransitivePHIs, Phi, 1);
 
   // A nested set of PHINodes can be constantfolded if:
   // - It has a constant input.



More information about the llvm-commits mailing list