[llvm] [SimpleLoopUnswitch] Remove callbacks (PR #73300)

Aiden Grossman via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 24 01:48:50 PST 2023


https://github.com/boomanaiden154 updated https://github.com/llvm/llvm-project/pull/73300

>From 352524cd350fcf5b193eb4dae67374dc58c93133 Mon Sep 17 00:00:00 2001
From: Aiden Grossman <agrossman154 at yahoo.com>
Date: Fri, 24 Nov 2023 01:16:09 -0800
Subject: [PATCH 1/2] [SimpleLoopUnswitch] Remove callbacks

After the removal of the legacy PM version of SimpleLoopUnswitch, the
callback mechanism is no longer needed to handle PM-specific tasks.
This patch removes the callbacks and uses the LPMUpdater directly,
which simplifies the code.
---
 .../Transforms/Scalar/SimpleLoopUnswitch.cpp  | 162 ++++++++----------
 1 file changed, 74 insertions(+), 88 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index 55606473765239b..cd00ef1fbbfc86b 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -1682,13 +1682,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
     BB->eraseFromParent();
 }
 
-static void
-deleteDeadBlocksFromLoop(Loop &L,
-                         SmallVectorImpl<BasicBlock *> &ExitBlocks,
-                         DominatorTree &DT, LoopInfo &LI,
-                         MemorySSAUpdater *MSSAU,
-                         ScalarEvolution *SE,
-                         function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static void deleteDeadBlocksFromLoop(Loop &L,
+                                     SmallVectorImpl<BasicBlock *> &ExitBlocks,
+                                     DominatorTree &DT, LoopInfo &LI,
+                                     MemorySSAUpdater *MSSAU,
+                                     ScalarEvolution *SE,
+                                     LPMUpdater &LoopUpdater) {
   // Find all the dead blocks tied to this loop, and remove them from their
   // successors.
   SmallSetVector<BasicBlock *, 8> DeadBlockSet;
@@ -1738,7 +1737,7 @@ deleteDeadBlocksFromLoop(Loop &L,
                         }) &&
            "If the child loop header is dead all blocks in the child loop must "
            "be dead as well!");
-    DestroyLoopCB(*ChildL, ChildL->getName());
+    LoopUpdater.markLoopAsDeleted(*ChildL, ChildL->getName());
     if (SE)
       SE->forgetBlockAndLoopDispositions();
     LI.destroy(ChildL);
@@ -2082,8 +2081,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
       ParentL->removeChildLoop(llvm::find(*ParentL, &L));
     else
       LI.removeLoop(llvm::find(LI, &L));
-    // markLoopAsDeleted for L should be triggered by the caller (it is typically
-    // done by using the UnswitchCB callback).
+    // markLoopAsDeleted for L should be triggered by the caller (it is
+    // typically done within postUnswitch).
     if (SE)
       SE->forgetBlockAndLoopDispositions();
     LI.destroy(&L);
@@ -2120,18 +2119,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
   } while (!DomWorklist.empty());
 }
 
+/// Report the outcome of unswitching \p L to the loop pass manager \p U:
+/// register \p NewLoops as siblings, then either mark \p L deleted (using the
+/// caller-owned \p LoopName, since \p L may already be destroyed), tag it so a
+/// partial/injected-condition unswitch is not repeated, or revisit it.
+static void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
+                         bool CurrentLoopValid, bool PartiallyInvariant,
+                         bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
+  // If we did a non-trivial unswitch, we have added new (cloned) loops.
+  if (!NewLoops.empty())
+    U.addSiblingLoops(NewLoops);
+
+  // If the current loop no longer exists, we need to mark it as deleted.
+  if (!CurrentLoopValid) {
+    U.markLoopAsDeleted(L, LoopName);
+    return;
+  }
+
+  // Tag a partially unswitched loop, or one with an injected invariant
+  // condition, so the same condition is not unswitched again. Any other valid
+  // loop is revisited to catch further unswitch opportunities.
+  auto &Context = L.getHeader()->getContext();
+  auto MarkUnswitched = [&](StringRef Prefix, StringRef DisableAttr) {
+    MDNode *DisableUnswitchMD =
+        MDNode::get(Context, MDString::get(Context, DisableAttr));
+    L.setLoopID(makePostTransformationMetadata(Context, L.getLoopID(),
+                                               {Prefix}, {DisableUnswitchMD}));
+  };
+  if (PartiallyInvariant)
+    MarkUnswitched("llvm.loop.unswitch.partial",
+                   "llvm.loop.unswitch.partial.disable");
+  else if (InjectedCondition)
+    MarkUnswitched("llvm.loop.unswitch.injection",
+                   "llvm.loop.unswitch.injection.disable");
+  else
+    U.revisitCurrentLoop();
+}
+
 static void unswitchNontrivialInvariants(
     Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
     IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
-    AssumptionCache &AC,
-    function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-    ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-    function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze,
-    bool InjectedCondition) {
+    AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
+    LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) {
   auto *ParentBB = TI.getParent();
   BranchInst *BI = dyn_cast<BranchInst>(&TI);
   SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
 
+  // Save the current loop name in a variable so that we can report it even
+  // after it has been deleted.
+  std::string LoopName(L.getName());
+
   // We can only unswitch switches, conditional branches with an invariant
   // condition, or combining invariant conditions with an instruction or
   // partially invariant instructions.
@@ -2444,7 +2481,7 @@ static void unswitchNontrivialInvariants(
   // Now that our cloned loops have been built, we can update the original loop.
   // First we delete the dead blocks from it and then we rebuild the loop
   // structure taking these deletions into account.
-  deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB);
+  deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater);
 
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -2580,7 +2617,8 @@ static void unswitchNontrivialInvariants(
   for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
     if (UpdatedL->getParentLoop() == ParentL)
       SibLoops.push_back(UpdatedL);
-  UnswitchCB(IsStillLoop, PartiallyInvariant, InjectedCondition, SibLoops);
+  postUnswitch(L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant,
+               InjectedCondition, SibLoops);
 
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -3427,12 +3465,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
       Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT);
 }
 
-static bool unswitchBestCondition(
-    Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    AAResults &AA, TargetTransformInfo &TTI,
-    function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-    ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-    function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                                  AssumptionCache &AC, AAResults &AA,
+                                  TargetTransformInfo &TTI, ScalarEvolution *SE,
+                                  MemorySSAUpdater *MSSAU,
+                                  LPMUpdater &LoopUpdater) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
   SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
@@ -3495,8 +3532,8 @@ static bool unswitchBestCondition(
   LLVM_DEBUG(dbgs() << "  Unswitching non-trivial (cost = " << Best.Cost
                     << ") terminator: " << *Best.TI << "\n");
   unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT,
-                               LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB,
-                               InsertFreeze, InjectedCondition);
+                               LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze,
+                               InjectedCondition);
   return true;
 }
 
@@ -3515,20 +3552,18 @@ static bool unswitchBestCondition(
 /// true, we will attempt to do non-trivial unswitching as well as trivial
 /// unswitching.
 ///
-/// The `UnswitchCB` callback provided will be run after unswitching is
-/// complete, with the first parameter set to `true` if the provided loop
-/// remains a loop, and a list of new sibling loops created.
+/// Once unswitching is complete, `postUnswitch` is run to update the loop
+/// pass manager, given whether the provided loop remains a loop and the list
+/// of new sibling loops that were created.
 ///
 /// If `SE` is non-null, we will update that analysis based on the unswitching
 /// done.
-static bool
-unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-             AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
-             bool NonTrivial,
-             function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
-             ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
-             ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
-             function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                         AssumptionCache &AC, AAResults &AA,
+                         TargetTransformInfo &TTI, bool Trivial,
+                         bool NonTrivial, ScalarEvolution *SE,
+                         MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI,
+                         BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
 
@@ -3540,8 +3575,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
   if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) {
     // If we unswitched successfully we will want to clean up the loop before
     // processing it further so just mark it as unswitched and return.
-    UnswitchCB(/*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
-               /*InjectedCondition*/ false, {});
+    postUnswitch(L, LoopUpdater, L.getName(), true, false, false, {});
     return true;
   }
 
@@ -3610,8 +3644,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
 
   // Try to unswitch the best invariant condition. We prefer this full unswitch to
   // a partial unswitch when possible below the threshold.
-  if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
-                            DestroyLoopCB))
+  if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater))
     return true;
 
   // No other opportunities to unswitch.
@@ -3631,52 +3664,6 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
   LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L
                     << "\n");
 
-  // Save the current loop name in a variable so that we can report it even
-  // after it has been deleted.
-  std::string LoopName = std::string(L.getName());
-
-  auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
-                                        bool PartiallyInvariant,
-                                        bool InjectedCondition,
-                                        ArrayRef<Loop *> NewLoops) {
-    // If we did a non-trivial unswitch, we have added new (cloned) loops.
-    if (!NewLoops.empty())
-      U.addSiblingLoops(NewLoops);
-
-    // If the current loop remains valid, we should revisit it to catch any
-    // other unswitch opportunities. Otherwise, we need to mark it as deleted.
-    if (CurrentLoopValid) {
-      if (PartiallyInvariant) {
-        // Mark the new loop as partially unswitched, to avoid unswitching on
-        // the same condition again.
-        auto &Context = L.getHeader()->getContext();
-        MDNode *DisableUnswitchMD = MDNode::get(
-            Context,
-            MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
-        MDNode *NewLoopID = makePostTransformationMetadata(
-            Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
-            {DisableUnswitchMD});
-        L.setLoopID(NewLoopID);
-      } else if (InjectedCondition) {
-        // Do the same for injection of invariant conditions.
-        auto &Context = L.getHeader()->getContext();
-        MDNode *DisableUnswitchMD = MDNode::get(
-            Context,
-            MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
-        MDNode *NewLoopID = makePostTransformationMetadata(
-            Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
-            {DisableUnswitchMD});
-        L.setLoopID(NewLoopID);
-      } else
-        U.revisitCurrentLoop();
-    } else
-      U.markLoopAsDeleted(L, LoopName);
-  };
-
-  auto DestroyLoopCB = [&U](Loop &L, StringRef Name) {
-    U.markLoopAsDeleted(L, Name);
-  };
-
   std::optional<MemorySSAUpdater> MSSAU;
   if (AR.MSSA) {
     MSSAU = MemorySSAUpdater(AR.MSSA);
@@ -3684,8 +3671,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
       AR.MSSA->verifyMemorySSA();
   }
   if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
-                    UnswitchCB, &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI,
-                    DestroyLoopCB))
+                    &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U))
     return PreservedAnalyses::all();
 
   if (AR.MSSA && VerifyMemorySSA)

>From 54341e06a305fcc097f893bd84342b3de3d54df2 Mon Sep 17 00:00:00 2001
From: Aiden Grossman <agrossman154 at yahoo.com>
Date: Fri, 24 Nov 2023 01:48:38 -0800
Subject: [PATCH 2/2] Re-add parameter comments to postUnswitch call

---
 llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index cd00ef1fbbfc86b..9c61a5619388989 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -3575,7 +3575,10 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
   if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) {
     // If we unswitched successfully we will want to clean up the loop before
     // processing it further so just mark it as unswitched and return.
-    postUnswitch(L, LoopUpdater, L.getName(), true, false, false, {});
+    postUnswitch(/*Loop*/ L, /*LoopUpdater*/ LoopUpdater,
+                 /*LoopName*/ L.getName(),
+                 /*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
+                 /*InjectedCondition*/ false, /*NewLoops*/ {});
     return true;
   }
 



More information about the llvm-commits mailing list