[llvm] [Convergence] Extend cycles to include outside uses of tokens (PR #98006)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 8 02:46:47 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Sameer Sahasrabuddhe (ssahasra)

<details>
<summary>Changes</summary>

When a convergence control token T defined at an operation D in a cycle
C is used by an operation U outside C, the cycle is said to be extended
up to U. This is because the use of the token T requires that two
threads that execute U execute converged dynamic instances of U if and
only if they previously executed converged dynamic instances of D.

For more information, including a high-level C-like example, see
https://llvm.org/docs/ConvergentOperations.html

This change introduces a pass that captures these token semantics by
literally extending the cycle C to include every path from C to U, as
illustrated below.

---

Patch is 36.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98006.diff


11 Files Affected:

- (modified) llvm/include/llvm/ADT/GenericCycleImpl.h (+115-11) 
- (modified) llvm/include/llvm/ADT/GenericCycleInfo.h (+12-7) 
- (modified) llvm/include/llvm/ADT/GenericSSAContext.h (+4) 
- (added) llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h (+28) 
- (modified) llvm/lib/CodeGen/MachineSSAContext.cpp (+6) 
- (modified) llvm/lib/IR/SSAContext.cpp (+13) 
- (modified) llvm/lib/Passes/PassBuilder.cpp (+1) 
- (modified) llvm/lib/Passes/PassRegistry.def (+1) 
- (modified) llvm/lib/Transforms/Scalar/CMakeLists.txt (+1) 
- (added) llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp (+249) 
- (added) llvm/test/Transforms/CycleConvergenceExtend/basic.ll (+405) 


``````````diff
diff --git a/llvm/include/llvm/ADT/GenericCycleImpl.h b/llvm/include/llvm/ADT/GenericCycleImpl.h
index ab9c421a44693..447151ca33ee0 100644
--- a/llvm/include/llvm/ADT/GenericCycleImpl.h
+++ b/llvm/include/llvm/ADT/GenericCycleImpl.h
@@ -177,26 +177,41 @@ auto GenericCycleInfo<ContextT>::getTopLevelParentCycle(BlockT *Block)
 }
 
 template <typename ContextT>
-void GenericCycleInfo<ContextT>::moveTopLevelCycleToNewParent(CycleT *NewParent,
-                                                              CycleT *Child) {
-  assert((!Child->ParentCycle && !NewParent->ParentCycle) &&
-         "NewParent and Child must be both top level cycle!\n");
-  auto &CurrentContainer =
-      Child->ParentCycle ? Child->ParentCycle->Children : TopLevelCycles;
+void GenericCycleInfo<ContextT>::moveToAdjacentCycle(CycleT *NewParent,
+                                                     CycleT *Child) {
+  auto *OldParent = Child->getParentCycle();
+  assert(!OldParent || OldParent->contains(NewParent));
+
+  // Find the child in its current parent (or toplevel) and move it out of its
+  // container, into the new parent.
+  auto &CurrentContainer = OldParent ? OldParent->Children : TopLevelCycles;
   auto Pos = llvm::find_if(CurrentContainer, [=](const auto &Ptr) -> bool {
     return Child == Ptr.get();
   });
   assert(Pos != CurrentContainer.end());
   NewParent->Children.push_back(std::move(*Pos));
+  // Pos is empty after moving the child out. So we move the last child into its
+  // place rather than refilling the whole container.
   *Pos = std::move(CurrentContainer.back());
   CurrentContainer.pop_back();
+
   Child->ParentCycle = NewParent;
 
-  NewParent->Blocks.insert(Child->block_begin(), Child->block_end());
+  // Add child blocks to the hierarchy up to the old parent.
+  auto *ParentIter = NewParent;
+  while (ParentIter != OldParent) {
+    ParentIter->Blocks.insert(Child->block_begin(), Child->block_end());
+    ParentIter = ParentIter->getParentCycle();
+  }
 
-  for (auto &It : BlockMapTopLevel)
-    if (It.second == Child)
-      It.second = NewParent;
+  // If Child was a top-level cycle, update the map.
+  if (!OldParent) {
+    auto *H = NewParent->getHeader();
+    auto *NewTLC = getTopLevelParentCycle(H);
+    for (auto &It : BlockMapTopLevel)
+      if (It.second == Child)
+        It.second = NewTLC;
+  }
 }
 
 template <typename ContextT>
@@ -286,7 +301,7 @@ void GenericCycleInfoCompute<ContextT>::run(BlockT *EntryBlock) {
                      << "discovered child cycle "
                      << Info.Context.print(BlockParent->getHeader()) << "\n");
           // Make BlockParent the child of NewCycle.
-          Info.moveTopLevelCycleToNewParent(NewCycle.get(), BlockParent);
+          Info.moveToAdjacentCycle(NewCycle.get(), BlockParent);
 
           for (auto *ChildEntry : BlockParent->entries())
             ProcessPredecessors(ChildEntry);
@@ -409,6 +424,95 @@ void GenericCycleInfo<ContextT>::splitCriticalEdge(BlockT *Pred, BlockT *Succ,
   assert(validateTree());
 }
 
+/// \brief Extend a cycle minimally such that it contains every path from that
+///        cycle reaching a given block.
+///
+/// The cycle structure is updated such that all predecessors of \p toBlock will
+/// be contained (possibly indirectly) in \p cycleToExtend, without removing any
+/// cycles.
+///
+/// If \p transferredBlocks is non-null, all blocks whose direct containing
+/// cycle was changed are appended to the vector.
+template <typename ContextT>
+void GenericCycleInfo<ContextT>::extendCycle(
+    CycleT *cycleToExtend, BlockT *toBlock,
+    SmallVectorImpl<BlockT *> *transferredBlocks) {
+  SmallVector<BlockT *> workList;
+  workList.push_back(toBlock);
+
+  assert(cycleToExtend);
+  while (!workList.empty()) {
+    BlockT *block = workList.pop_back_val();
+    CycleT *cycle = getCycle(block);
+    if (cycleToExtend->contains(cycle))
+      continue;
+
+    auto cycleToInclude = findLargestDisjointAncestor(cycle, cycleToExtend);
+    if (cycleToInclude) {
+      // Move cycle into cycleToExtend.
+      moveToAdjacentCycle(cycleToExtend, cycleToInclude);
+      assert(cycleToInclude->Depth <= cycleToExtend->Depth);
+      GenericCycleInfoCompute<ContextT>::updateDepth(cycleToInclude);
+
+      // Continue from the entries of the newly included cycle.
+      for (BlockT *entry : cycleToInclude->Entries)
+        llvm::append_range(workList, predecessors(entry));
+    } else {
+      // Block is contained in an ancestor of cycleToExtend, just add it
+      // to the cycle and proceed.
+      BlockMap[block] = cycleToExtend;
+      if (transferredBlocks)
+        transferredBlocks->push_back(block);
+
+      CycleT *ancestor = cycleToExtend;
+      do {
+        ancestor->Blocks.insert(block);
+        ancestor = ancestor->getParentCycle();
+      } while (ancestor != cycle);
+
+      llvm::append_range(workList, predecessors(block));
+    }
+  }
+
+  assert(validateTree());
+}
+
+/// \brief Finds the largest ancestor of \p A that is disjoint from \p B.
+///
+/// The caller must ensure that \p B does not contain \p A. If \p A
+/// contains \p B, null is returned.
+template <typename ContextT>
+auto GenericCycleInfo<ContextT>::findLargestDisjointAncestor(
+    const CycleT *A, const CycleT *B) const -> CycleT * {
+  if (!A || !B)
+    return nullptr;
+
+  while (B && A->Depth < B->Depth)
+    B = B->ParentCycle;
+  while (A && A->Depth > B->Depth)
+    A = A->ParentCycle;
+
+  if (A == B)
+    return nullptr;
+
+  assert(A && B);
+  assert(A->Depth == B->Depth);
+
+  for (;;) {
+    // Since both are at the same depth, the only way for both A and B to be
+    // null is when their parents are null, which will terminate the loop.
+    assert(A && B);
+
+    if (A->ParentCycle == B->ParentCycle) {
+      // const_cast is justified since cycles are owned by this
+      // object, which is non-const.
+      return const_cast<CycleT *>(A);
+    }
+    A = A->ParentCycle;
+    B = B->ParentCycle;
+  }
+}
+
 /// \brief Find the innermost cycle containing a given block.
 ///
 /// \returns the innermost cycle containing \p Block or nullptr if
diff --git a/llvm/include/llvm/ADT/GenericCycleInfo.h b/llvm/include/llvm/ADT/GenericCycleInfo.h
index b601fc9bae38a..fd68bfe40ce64 100644
--- a/llvm/include/llvm/ADT/GenericCycleInfo.h
+++ b/llvm/include/llvm/ADT/GenericCycleInfo.h
@@ -250,13 +250,7 @@ template <typename ContextT> class GenericCycleInfo {
   ///
   /// Note: This is an incomplete operation that does not update the depth of
   /// the subtree.
-  void moveTopLevelCycleToNewParent(CycleT *NewParent, CycleT *Child);
-
-  /// Assumes that \p Cycle is the innermost cycle containing \p Block.
-  /// \p Block will be appended to \p Cycle and all of its parent cycles.
-  /// \p Block will be added to BlockMap with \p Cycle and
-  /// BlockMapTopLevel with \p Cycle's top level parent cycle.
-  void addBlockToCycle(BlockT *Block, CycleT *Cycle);
+  void moveToAdjacentCycle(CycleT *NewParent, CycleT *Child);
 
 public:
   GenericCycleInfo() = default;
@@ -275,6 +269,15 @@ template <typename ContextT> class GenericCycleInfo {
   unsigned getCycleDepth(const BlockT *Block) const;
   CycleT *getTopLevelParentCycle(BlockT *Block);
 
+  /// Assumes that \p Cycle is the innermost cycle containing \p Block.
+  /// \p Block will be appended to \p Cycle and all of its parent cycles.
+  /// \p Block will be added to BlockMap with \p Cycle and
+  /// BlockMapTopLevel with \p Cycle's top level parent cycle.
+  void addBlockToCycle(BlockT *Block, CycleT *Cycle);
+
+  void extendCycle(CycleT *cycleToExtend, BlockT *toBlock,
+                   SmallVectorImpl<BlockT *> *transferredBlocks = nullptr);
+
   /// Methods for debug and self-test.
   //@{
 #ifndef NDEBUG
@@ -285,6 +288,8 @@ template <typename ContextT> class GenericCycleInfo {
   Printable print(const CycleT *Cycle) { return Cycle->print(Context); }
   //@}
 
+  CycleT *findLargestDisjointAncestor(const CycleT *a, const CycleT *b) const;
+
   /// Iteration over top-level cycles.
   //@{
   using const_toplevel_iterator_base =
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b..480fe1a8f1511 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -93,6 +93,9 @@ template <typename _FunctionT> class GenericSSAContext {
   static void appendBlockTerms(SmallVectorImpl<const InstructionT *> &terms,
                                const BlockT &block);
 
+  static void appendConvergenceTokenUses(std::vector<BlockT *> &Worklist,
+                                         BlockT &BB);
+
   static bool isConstantOrUndefValuePhi(const InstructionT &Instr);
   const BlockT *getDefBlock(ConstValueRefT value) const;
 
@@ -101,6 +104,7 @@ template <typename _FunctionT> class GenericSSAContext {
   Printable print(const InstructionT *inst) const;
   Printable print(ConstValueRefT value) const;
 };
+
 } // namespace llvm
 
 #endif // LLVM_ADT_GENERICSSACONTEXT_H
diff --git a/llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h b/llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h
new file mode 100644
index 0000000000000..0e39452c3f213
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h
@@ -0,0 +1,28 @@
+//===- CycleConvergenceExtend.h - Extend cycles for convergence -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides the interface for the CycleConvergenceExtend pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_CYCLECONVERGENCEEXTEND_H
+#define LLVM_TRANSFORMS_SCALAR_CYCLECONVERGENCEEXTEND_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class CycleConvergenceExtendPass
+    : public PassInfoMixin<CycleConvergenceExtendPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_CYCLECONVERGENCEEXTEND_H
diff --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp
index e384187b6e859..200faf5a401dd 100644
--- a/llvm/lib/CodeGen/MachineSSAContext.cpp
+++ b/llvm/lib/CodeGen/MachineSSAContext.cpp
@@ -46,6 +46,12 @@ void MachineSSAContext::appendBlockTerms(
     terms.push_back(&T);
 }
 
+template <>
+void MachineSSAContext::appendConvergenceTokenUses(
+    std::vector<MachineBasicBlock *> &Worklist, MachineBasicBlock &BB) {
+  llvm_unreachable("Cycle extensions are not supported in MIR yet.");
+}
+
 /// Get the defining block of a value.
 template <>
 const MachineBasicBlock *MachineSSAContext::getDefBlock(Register value) const {
diff --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp
index 220abe3083ebd..3d9fb6d05bc5a 100644
--- a/llvm/lib/IR/SSAContext.cpp
+++ b/llvm/lib/IR/SSAContext.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/ModuleSlotTracker.h"
 #include "llvm/Support/raw_ostream.h"
@@ -55,6 +56,18 @@ void SSAContext::appendBlockTerms(SmallVectorImpl<const Instruction *> &terms,
   terms.push_back(block.getTerminator());
 }
 
+template <>
+void SSAContext::appendConvergenceTokenUses(std::vector<BasicBlock *> &Worklist,
+                                            BasicBlock &BB) {
+  for (Instruction &I : BB) {
+    if (!isa<ConvergenceControlInst>(I))
+      continue;
+    for (User *U : I.users()) {
+      Worklist.push_back(cast<Instruction>(U)->getParent());
+    }
+  }
+}
+
 template <>
 const BasicBlock *SSAContext::getDefBlock(const Value *value) const {
   if (const auto *instruction = dyn_cast<Instruction>(value))
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 17cc156846d36..41c4912edf3de 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -201,6 +201,7 @@
 #include "llvm/Transforms/Scalar/ConstantHoisting.h"
 #include "llvm/Transforms/Scalar/ConstraintElimination.h"
 #include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
+#include "llvm/Transforms/Scalar/CycleConvergenceExtend.h"
 #include "llvm/Transforms/Scalar/DCE.h"
 #include "llvm/Transforms/Scalar/DFAJumpThreading.h"
 #include "llvm/Transforms/Scalar/DeadStoreElimination.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 3b92823cd283b..75386cd2929bf 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -335,6 +335,7 @@ FUNCTION_PASS("constraint-elimination", ConstraintEliminationPass())
 FUNCTION_PASS("coro-elide", CoroElidePass())
 FUNCTION_PASS("correlated-propagation", CorrelatedValuePropagationPass())
 FUNCTION_PASS("count-visits", CountVisitsPass())
+FUNCTION_PASS("cycle-convergence-extend", CycleConvergenceExtendPass())
 FUNCTION_PASS("dce", DCEPass())
 FUNCTION_PASS("declare-to-assign", llvm::AssignmentTrackingPass())
 FUNCTION_PASS("dfa-jump-threading", DFAJumpThreadingPass())
diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
index ba09ebf8b04c4..c6fc8e74bcb92 100644
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMScalarOpts
   ConstantHoisting.cpp
   ConstraintElimination.cpp
   CorrelatedValuePropagation.cpp
+  CycleConvergenceExtend.cpp
   DCE.cpp
   DeadStoreElimination.cpp
   DFAJumpThreading.cpp
diff --git a/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp b/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
new file mode 100644
index 0000000000000..db8e3942ae68b
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
@@ -0,0 +1,249 @@
+//===- CycleConvergenceExtend.cpp - Extend cycle body for convergence
+//--------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to extend cycles: if a token T defined in a cycle
+// L is used at U outside of L, then the entire cycle nest is modified so that
+// every path P from L to U is included in the body of L, including any sibling
+// cycles whose header lies on P.
+//
+// Input CFG:
+//
+//         +-------------------+
+//         | A: token %a = ... | <+
+//         +-------------------+  |
+//           |                    |
+//           v                    |
+//    +--> +-------------------+  |
+//    |    | B: token %b = ... |  |
+//    +--- +-------------------+  |
+//           |                    |
+//           v                    |
+//         +-------------------+  |
+//         |         C         | -+
+//         +-------------------+
+//           |
+//           v
+//         +-------------------+
+//         |  D: use token %b  |
+//         |     use token %a  |
+//         +-------------------+
+//
+// Both cycles in the above nest need to be extended to contain the respective
+// uses %d1 and %d2. To make this work, the block D needs to be split into two
+// blocks "D1;D2" so that D1 is absorbed by the inner cycle while D2 is absorbed
+// by the outer cycle.
+//
+// Transformed CFG:
+//
+//            +-------------------+
+//            | A: token %a = ... | <-----+
+//            +-------------------+       |
+//              |                         |
+//              v                         |
+//            +-------------------+       |
+//    +-----> | B: token %b = ... | -+    |
+//    |       +-------------------+  |    |
+//    |         |                    |    |
+//    |         v                    |    |
+//    |       +-------------------+  |    |
+//    |    +- |         C         |  |    |
+//    |    |  +-------------------+  |    |
+//    |    |    |                    |    |
+//    |    |    v                    |    |
+//    |    |  +-------------------+  |    |
+//    |    |  | D1: use token %b  |  |    |
+//    |    |  +-------------------+  |    |
+//    |    |    |                    |    |
+//    |    |    v                    |    |
+//    |    |  +-------------------+  |    |
+//    +----+- |       Flow1       | <+    |
+//         |  +-------------------+       |
+//         |    |                         |
+//         |    v                         |
+//         |  +-------------------+       |
+//         |  | D2: use token %a  |       |
+//         |  +-------------------+       |
+//         |    |                         |
+//         |    v                         |
+//         |  +-------------------+       |
+//         +> |       Flow2       | ------+
+//            +-------------------+
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/CycleConvergenceExtend.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#define DEBUG_TYPE "cycle-convergence-extend"
+
+using namespace llvm;
+
+using BBSetVector = SetVector<BasicBlock *>;
+using ExtensionMap = DenseMap<Cycle *, SmallVector<CallBase *>>;
+// A single BB very rarely defines more than one token.
+using TokenDefsMap = DenseMap<BasicBlock *, SmallVector<CallBase *, 1>>;
+using TokenDefUsesMap = DenseMap<CallBase *, SmallVector<CallBase *>>;
+
+static void updateTokenDefs(TokenDefsMap &TokenDefs, BasicBlock &BB) {
+  TokenDefsMap::mapped_type Defs;
+  for (Instruction &I : BB) {
+    if (isa<ConvergenceControlInst>(I))
+      Defs.push_back(cast<CallBase>(&I));
+  }
+  if (Defs.empty()) {
+    TokenDefs.erase(&BB);
+    return;
+  }
+  TokenDefs.insert_or_assign(&BB, std::move(Defs));
+}
+
+static bool splitForExtension(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
+                              CallBase *TokenUse, TokenDefsMap &TokenDefs) {
+  if (DefCycle->contains(BB))
+    return false;
+  BasicBlock *NewBB = BB->splitBasicBlockBefore(TokenUse->getNextNode(),
+                                                BB->getName() + ".ext");
+  if (Cycle *BBCycle = CI.getCycle(BB))
+    CI.addBlockToCycle(NewBB, BBCycle);
+  updateTokenDefs(TokenDefs, *BB);
+  updateTokenDefs(TokenDefs, *NewBB);
+  return true;
+}
+
+static void locateExtensions(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
+                             TokenDefsMap &TokenDefs,
+                             TokenDefUsesMap &TokenDefUses,
+                             SmallVectorImpl<CallBase *> &ExtPoints) {
+  if (auto Iter = TokenDefs.find(BB); Iter != TokenDefs.end()) {
+    for (CallBase *Def : Iter->second) {
+      for (CallBase *TokenUse : TokenDefUses[Def]) {
+        BasicBlock *BB = TokenUse->getParent();
+        if (splitForExtension(CI, DefCycle, BB, TokenUse, TokenDefs)) {
+          ExtPoints.push_back(TokenUse);
+        }
+      }
+    }
+  }
+}
+
+static void initialize(ExtensionMap &ExtBorder, TokenDefsMap &TokenDefs,
+                       TokenDefUsesMap &TokenDefUses, Function &F,
+                       CycleInfo &CI) {
+  for (BasicBlock &BB : F) {
+    updateTokenDefs(TokenDefs, BB);
+    for (Instruction &I : BB) {
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        if (auto *TokenDef =
+                cast_or_null<CallBase>(CB->getConvergenceControlToken())) {
+          TokenDefUses[TokenDef].push_back(CB);
+        }
+      }
+    }
+  }
+
+  for (BasicBlock &BB : F) {
+    if (Cycle *DefCycle = CI.getCycle(&BB)) {
+      SmallVector<CallBase *> ExtPoints;
+      locateExtensions(CI, DefCycle, &BB, TokenDefs, TokenDefUses, ExtPoints);
+      if (!ExtPoints.empty()) {
+        auto Success = ExtBorder.try_emplace(DefCycle, std::move(ExtPoints));
+        (void)Success;
+        assert(Success.second);
+      }
+    }
...
[truncated]

``````````
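
The new pass is exposed as the function pass `cycle-convergence-extend`
(see the PassRegistry.def hunk above), so the added basic.ll test is
presumably driven through `opt`. An illustrative RUN line, not copied
from the truncated test file:

```llvm
; Illustrative only; the real RUN lines of basic.ll are not shown above.
; RUN: opt -passes=cycle-convergence-extend -S %s | FileCheck %s
```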

</details>


https://github.com/llvm/llvm-project/pull/98006

