[llvm] [Convergence] Extend cycles to include outside uses of tokens (PR #98006)
Sameer Sahasrabuddhe via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 9 22:16:23 PDT 2024
https://github.com/ssahasra updated https://github.com/llvm/llvm-project/pull/98006
From 2ac53642e7b00db27cc2f799ebd633309c68f1be Mon Sep 17 00:00:00 2001
From: Sameer Sahasrabuddhe <sameer.sahasrabuddhe at amd.com>
Date: Mon, 8 Jul 2024 13:49:55 +0530
Subject: [PATCH 1/2] [Convergence] Extend cycles to include token uses outside
When a convergence control token T defined at an operation D in a cycle C is
used by an operation U outside C, the cycle is said to be extended up to U.
This is because the use of the convergence control token T requires that two
threads executing U execute converged dynamic instances of U if and only if
they previously executed converged dynamic instances of D.
For more information, including a high-level C-like example, see
https://llvm.org/docs/ConvergentOperations.html
This change introduces a pass that captures these token semantics by literally
extending the cycle C to include every path from C to U.
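As an illustrative sketch (not part of the patch), the pattern being handled
looks like the IR below, modelled on the tests added here; the anchor %a is
defined at A inside a cycle but used at E outside it, so the pass extends the
cycle to include E and every path leading to it. The callee @convergent.op
stands in for an arbitrary convergent operation, as in the tests.
  A:                                        ; cycle header
    %a = call token @llvm.experimental.convergence.anchor()
    br i1 %cond, label %A, label %B
  B:
    br label %E
  E:                                        ; outside the cycle before this pass
    call void @convergent.op(i32 0) [ "convergencectrl"(token %a) ]
    ret void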
---
llvm/include/llvm/ADT/GenericCycleImpl.h | 126 +++++-
llvm/include/llvm/ADT/GenericCycleInfo.h | 19 +-
llvm/include/llvm/ADT/GenericSSAContext.h | 4 +
.../Scalar/CycleConvergenceExtend.h | 28 ++
llvm/lib/CodeGen/MachineSSAContext.cpp | 6 +
llvm/lib/IR/SSAContext.cpp | 13 +
llvm/lib/Passes/PassBuilder.cpp | 1 +
llvm/lib/Passes/PassRegistry.def | 1 +
llvm/lib/Transforms/Scalar/CMakeLists.txt | 1 +
.../Scalar/CycleConvergenceExtend.cpp | 249 +++++++++++
.../CycleConvergenceExtend/basic.ll | 405 ++++++++++++++++++
11 files changed, 835 insertions(+), 18 deletions(-)
create mode 100644 llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h
create mode 100644 llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
create mode 100644 llvm/test/Transforms/CycleConvergenceExtend/basic.ll
diff --git a/llvm/include/llvm/ADT/GenericCycleImpl.h b/llvm/include/llvm/ADT/GenericCycleImpl.h
index ab9c421a44693..447151ca33ee0 100644
--- a/llvm/include/llvm/ADT/GenericCycleImpl.h
+++ b/llvm/include/llvm/ADT/GenericCycleImpl.h
@@ -177,26 +177,41 @@ auto GenericCycleInfo<ContextT>::getTopLevelParentCycle(BlockT *Block)
}
template <typename ContextT>
-void GenericCycleInfo<ContextT>::moveTopLevelCycleToNewParent(CycleT *NewParent,
- CycleT *Child) {
- assert((!Child->ParentCycle && !NewParent->ParentCycle) &&
- "NewParent and Child must be both top level cycle!\n");
- auto &CurrentContainer =
- Child->ParentCycle ? Child->ParentCycle->Children : TopLevelCycles;
+void GenericCycleInfo<ContextT>::moveToAdjacentCycle(CycleT *NewParent,
+ CycleT *Child) {
+ auto *OldParent = Child->getParentCycle();
+ assert(!OldParent || OldParent->contains(NewParent));
+
+ // Find the child in its current parent (or toplevel) and move it out of its
+ // container, into the new parent.
+ auto &CurrentContainer = OldParent ? OldParent->Children : TopLevelCycles;
auto Pos = llvm::find_if(CurrentContainer, [=](const auto &Ptr) -> bool {
return Child == Ptr.get();
});
assert(Pos != CurrentContainer.end());
NewParent->Children.push_back(std::move(*Pos));
+ // Pos is empty after moving the child out, so move the last child into its
+ // place rather than shifting all the remaining children.
*Pos = std::move(CurrentContainer.back());
CurrentContainer.pop_back();
+
Child->ParentCycle = NewParent;
- NewParent->Blocks.insert(Child->block_begin(), Child->block_end());
+ // Add child blocks to the hierarchy up to the old parent.
+ auto *ParentIter = NewParent;
+ while (ParentIter != OldParent) {
+ ParentIter->Blocks.insert(Child->block_begin(), Child->block_end());
+ ParentIter = ParentIter->getParentCycle();
+ }
- for (auto &It : BlockMapTopLevel)
- if (It.second == Child)
- It.second = NewParent;
+ // If Child was a top-level cycle, update the map.
+ if (!OldParent) {
+ auto *H = NewParent->getHeader();
+ auto *NewTLC = getTopLevelParentCycle(H);
+ for (auto &It : BlockMapTopLevel)
+ if (It.second == Child)
+ It.second = NewTLC;
+ }
}
template <typename ContextT>
@@ -286,7 +301,7 @@ void GenericCycleInfoCompute<ContextT>::run(BlockT *EntryBlock) {
<< "discovered child cycle "
<< Info.Context.print(BlockParent->getHeader()) << "\n");
// Make BlockParent the child of NewCycle.
- Info.moveTopLevelCycleToNewParent(NewCycle.get(), BlockParent);
+ Info.moveToAdjacentCycle(NewCycle.get(), BlockParent);
for (auto *ChildEntry : BlockParent->entries())
ProcessPredecessors(ChildEntry);
@@ -409,6 +424,95 @@ void GenericCycleInfo<ContextT>::splitCriticalEdge(BlockT *Pred, BlockT *Succ,
assert(validateTree());
}
+/// \brief Extend a cycle minimally such that it contains every path from that
+/// cycle reaching a given block.
+///
+/// The cycle structure is updated such that all predecessors of \p toBlock will
+/// be contained (possibly indirectly) in \p cycleToExtend, without removing any
+/// cycles.
+///
+/// If \p transferredBlocks is non-null, all blocks whose direct containing
+/// cycle was changed are appended to the vector.
+template <typename ContextT>
+void GenericCycleInfo<ContextT>::extendCycle(
+ CycleT *cycleToExtend, BlockT *toBlock,
+ SmallVectorImpl<BlockT *> *transferredBlocks) {
+ SmallVector<BlockT *> workList;
+ workList.push_back(toBlock);
+
+ assert(cycleToExtend);
+ while (!workList.empty()) {
+ BlockT *block = workList.pop_back_val();
+ CycleT *cycle = getCycle(block);
+ if (cycleToExtend->contains(cycle))
+ continue;
+
+ auto cycleToInclude = findLargestDisjointAncestor(cycle, cycleToExtend);
+ if (cycleToInclude) {
+ // Move cycle into cycleToExtend.
+ moveToAdjacentCycle(cycleToExtend, cycleToInclude);
+ assert(cycleToInclude->Depth <= cycleToExtend->Depth);
+ GenericCycleInfoCompute<ContextT>::updateDepth(cycleToInclude);
+
+ // Continue from the entries of the newly included cycle.
+ for (BlockT *entry : cycleToInclude->Entries)
+ llvm::append_range(workList, predecessors(entry));
+ } else {
+ // Block is contained in an ancestor of cycleToExtend, just add it
+ // to the cycle and proceed.
+ BlockMap[block] = cycleToExtend;
+ if (transferredBlocks)
+ transferredBlocks->push_back(block);
+
+ CycleT *ancestor = cycleToExtend;
+ do {
+ ancestor->Blocks.insert(block);
+ ancestor = ancestor->getParentCycle();
+ } while (ancestor != cycle);
+
+ llvm::append_range(workList, predecessors(block));
+ }
+ }
+
+ assert(validateTree());
+}
+
+/// \brief Finds the largest ancestor of \p A that is disjoint from \p B.
+///
+/// The caller must ensure that \p B does not contain \p A. If \p A
+/// contains \p B, null is returned.
+template <typename ContextT>
+auto GenericCycleInfo<ContextT>::findLargestDisjointAncestor(
+ const CycleT *A, const CycleT *B) const -> CycleT * {
+ if (!A || !B)
+ return nullptr;
+
+ while (B && A->Depth < B->Depth)
+ B = B->ParentCycle;
+ while (A && A->Depth > B->Depth)
+ A = A->ParentCycle;
+
+ if (A == B)
+ return nullptr;
+
+ assert(A && B);
+ assert(A->Depth == B->Depth);
+
+ for (;;) {
+ // Since both are at the same depth, A and B can only become null together,
+ // but the loop returns before that happens because their (then null)
+ // parents compare equal.
+ assert(A && B);
+
+ if (A->ParentCycle == B->ParentCycle) {
+ // const_cast is justified since cycles are owned by this
+ // object, which is non-const.
+ return const_cast<CycleT *>(A);
+ }
+ A = A->ParentCycle;
+ B = B->ParentCycle;
+ }
+}
+
/// \brief Find the innermost cycle containing a given block.
///
/// \returns the innermost cycle containing \p Block or nullptr if
diff --git a/llvm/include/llvm/ADT/GenericCycleInfo.h b/llvm/include/llvm/ADT/GenericCycleInfo.h
index b601fc9bae38a..fd68bfe40ce64 100644
--- a/llvm/include/llvm/ADT/GenericCycleInfo.h
+++ b/llvm/include/llvm/ADT/GenericCycleInfo.h
@@ -250,13 +250,7 @@ template <typename ContextT> class GenericCycleInfo {
///
/// Note: This is an incomplete operation that does not update the depth of
/// the subtree.
- void moveTopLevelCycleToNewParent(CycleT *NewParent, CycleT *Child);
-
- /// Assumes that \p Cycle is the innermost cycle containing \p Block.
- /// \p Block will be appended to \p Cycle and all of its parent cycles.
- /// \p Block will be added to BlockMap with \p Cycle and
- /// BlockMapTopLevel with \p Cycle's top level parent cycle.
- void addBlockToCycle(BlockT *Block, CycleT *Cycle);
+ void moveToAdjacentCycle(CycleT *NewParent, CycleT *Child);
public:
GenericCycleInfo() = default;
@@ -275,6 +269,15 @@ template <typename ContextT> class GenericCycleInfo {
unsigned getCycleDepth(const BlockT *Block) const;
CycleT *getTopLevelParentCycle(BlockT *Block);
+ /// Assumes that \p Cycle is the innermost cycle containing \p Block.
+ /// \p Block will be appended to \p Cycle and all of its parent cycles.
+ /// \p Block will be added to BlockMap with \p Cycle and
+ /// BlockMapTopLevel with \p Cycle's top level parent cycle.
+ void addBlockToCycle(BlockT *Block, CycleT *Cycle);
+
+ void extendCycle(CycleT *cycleToExtend, BlockT *toBlock,
+ SmallVectorImpl<BlockT *> *transferredBlocks = nullptr);
+
/// Methods for debug and self-test.
//@{
#ifndef NDEBUG
@@ -285,6 +288,8 @@ template <typename ContextT> class GenericCycleInfo {
Printable print(const CycleT *Cycle) { return Cycle->print(Context); }
//@}
+ CycleT *findLargestDisjointAncestor(const CycleT *A, const CycleT *B) const;
+
/// Iteration over top-level cycles.
//@{
using const_toplevel_iterator_base =
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b..480fe1a8f1511 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -93,6 +93,9 @@ template <typename _FunctionT> class GenericSSAContext {
static void appendBlockTerms(SmallVectorImpl<const InstructionT *> &terms,
const BlockT &block);
+ static void appendConvergenceTokenUses(std::vector<BlockT *> &Worklist,
+ BlockT &BB);
+
static bool isConstantOrUndefValuePhi(const InstructionT &Instr);
const BlockT *getDefBlock(ConstValueRefT value) const;
@@ -101,6 +104,7 @@ template <typename _FunctionT> class GenericSSAContext {
Printable print(const InstructionT *inst) const;
Printable print(ConstValueRefT value) const;
};
+
} // namespace llvm
#endif // LLVM_ADT_GENERICSSACONTEXT_H
diff --git a/llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h b/llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h
new file mode 100644
index 0000000000000..0e39452c3f213
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Scalar/CycleConvergenceExtend.h
@@ -0,0 +1,28 @@
+//===- CycleConvergenceExtend.h - Extend cycles for convergence -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides the interface for the CycleConvergenceExtend pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_CYCLECONVERGENCEEXTEND_H
+#define LLVM_TRANSFORMS_SCALAR_CYCLECONVERGENCEEXTEND_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class CycleConvergenceExtendPass
+ : public PassInfoMixin<CycleConvergenceExtendPass> {
+public:
+ PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_CYCLECONVERGENCEEXTEND_H
diff --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp
index e384187b6e859..200faf5a401dd 100644
--- a/llvm/lib/CodeGen/MachineSSAContext.cpp
+++ b/llvm/lib/CodeGen/MachineSSAContext.cpp
@@ -46,6 +46,12 @@ void MachineSSAContext::appendBlockTerms(
terms.push_back(&T);
}
+template <>
+void MachineSSAContext::appendConvergenceTokenUses(
+ std::vector<MachineBasicBlock *> &Worklist, MachineBasicBlock &BB) {
+ llvm_unreachable("Cycle extensions are not supported in MIR yet.");
+}
+
/// Get the defining block of a value.
template <>
const MachineBasicBlock *MachineSSAContext::getDefBlock(Register value) const {
diff --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp
index 220abe3083ebd..3d9fb6d05bc5a 100644
--- a/llvm/lib/IR/SSAContext.cpp
+++ b/llvm/lib/IR/SSAContext.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/ModuleSlotTracker.h"
#include "llvm/Support/raw_ostream.h"
@@ -55,6 +56,18 @@ void SSAContext::appendBlockTerms(SmallVectorImpl<const Instruction *> &terms,
terms.push_back(block.getTerminator());
}
+template <>
+void SSAContext::appendConvergenceTokenUses(std::vector<BasicBlock *> &Worklist,
+ BasicBlock &BB) {
+ for (Instruction &I : BB) {
+ if (!isa<ConvergenceControlInst>(I))
+ continue;
+ for (User *U : I.users()) {
+ Worklist.push_back(cast<Instruction>(U)->getParent());
+ }
+ }
+}
+
template <>
const BasicBlock *SSAContext::getDefBlock(const Value *value) const {
if (const auto *instruction = dyn_cast<Instruction>(value))
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 17cc156846d36..41c4912edf3de 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -201,6 +201,7 @@
#include "llvm/Transforms/Scalar/ConstantHoisting.h"
#include "llvm/Transforms/Scalar/ConstraintElimination.h"
#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
+#include "llvm/Transforms/Scalar/CycleConvergenceExtend.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Scalar/DFAJumpThreading.h"
#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 3b92823cd283b..75386cd2929bf 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -335,6 +335,7 @@ FUNCTION_PASS("constraint-elimination", ConstraintEliminationPass())
FUNCTION_PASS("coro-elide", CoroElidePass())
FUNCTION_PASS("correlated-propagation", CorrelatedValuePropagationPass())
FUNCTION_PASS("count-visits", CountVisitsPass())
+FUNCTION_PASS("cycle-convergence-extend", CycleConvergenceExtendPass())
FUNCTION_PASS("dce", DCEPass())
FUNCTION_PASS("declare-to-assign", llvm::AssignmentTrackingPass())
FUNCTION_PASS("dfa-jump-threading", DFAJumpThreadingPass())
diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
index ba09ebf8b04c4..c6fc8e74bcb92 100644
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMScalarOpts
ConstantHoisting.cpp
ConstraintElimination.cpp
CorrelatedValuePropagation.cpp
+ CycleConvergenceExtend.cpp
DCE.cpp
DeadStoreElimination.cpp
DFAJumpThreading.cpp
diff --git a/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp b/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
new file mode 100644
index 0000000000000..db8e3942ae68b
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
@@ -0,0 +1,249 @@
+//===- CycleConvergenceExtend.cpp - Extend cycle body for convergence -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to extend cycles: if a token T defined in a cycle
+// L is used at U outside of L, then the entire cycle nest is modified so that
+// every path P from L to U is included in the body of L, including any sibling
+// cycles whose header lies on P.
+//
+// Input CFG:
+//
+// +-------------------+
+// | A: token %a = ... | <+
+// +-------------------+ |
+// | |
+// v |
+// +--> +-------------------+ |
+// | | B: token %b = ... | |
+// +--- +-------------------+ |
+// | |
+// v |
+// +-------------------+ |
+// | C | -+
+// +-------------------+
+// |
+// v
+// +-------------------+
+// | D: use token %b |
+// | use token %a |
+// +-------------------+
+//
+// Both cycles in the above nest need to be extended to contain the respective
+// uses of %b and %a in block D. To make this work, D needs to be split into
+// two blocks D1 and D2, so that D1 is absorbed by the inner cycle while D2 is
+// absorbed by the outer cycle.
+//
+// Transformed CFG:
+//
+// +-------------------+
+// | A: token %a = ... | <-----+
+// +-------------------+ |
+// | |
+// v |
+// +-------------------+ |
+// +-----> | B: token %b = ... | -+ |
+// | +-------------------+ | |
+// | | | |
+// | v | |
+// | +-------------------+ | |
+// | +- | C | | |
+// | | +-------------------+ | |
+// | | | | |
+// | | v | |
+// | | +-------------------+ | |
+// | | | D1: use token %b | | |
+// | | +-------------------+ | |
+// | | | | |
+// | | v | |
+// | | +-------------------+ | |
+// +----+- | Flow1 | <+ |
+// | +-------------------+ |
+// | | |
+// | v |
+// | +-------------------+ |
+// | | D2: use token %a | |
+// | +-------------------+ |
+// | | |
+// | v |
+// | +-------------------+ |
+// +> | Flow2 | ------+
+// +-------------------+
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/CycleConvergenceExtend.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#define DEBUG_TYPE "cycle-convergence-extend"
+
+using namespace llvm;
+
+using BBSetVector = SetVector<BasicBlock *>;
+using ExtensionMap = DenseMap<Cycle *, SmallVector<CallBase *>>;
+// A single BB very rarely defines more than one token.
+using TokenDefsMap = DenseMap<BasicBlock *, SmallVector<CallBase *, 1>>;
+using TokenDefUsesMap = DenseMap<CallBase *, SmallVector<CallBase *>>;
+
+static void updateTokenDefs(TokenDefsMap &TokenDefs, BasicBlock &BB) {
+ TokenDefsMap::mapped_type Defs;
+ for (Instruction &I : BB) {
+ if (isa<ConvergenceControlInst>(I))
+ Defs.push_back(cast<CallBase>(&I));
+ }
+ if (Defs.empty()) {
+ TokenDefs.erase(&BB);
+ return;
+ }
+ TokenDefs.insert_or_assign(&BB, std::move(Defs));
+}
+
+static bool splitForExtension(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
+ CallBase *TokenUse, TokenDefsMap &TokenDefs) {
+ if (DefCycle->contains(BB))
+ return false;
+ BasicBlock *NewBB = BB->splitBasicBlockBefore(TokenUse->getNextNode(),
+ BB->getName() + ".ext");
+ if (Cycle *BBCycle = CI.getCycle(BB))
+ CI.addBlockToCycle(NewBB, BBCycle);
+ updateTokenDefs(TokenDefs, *BB);
+ updateTokenDefs(TokenDefs, *NewBB);
+ return true;
+}
+
+static void locateExtensions(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
+ TokenDefsMap &TokenDefs,
+ TokenDefUsesMap &TokenDefUses,
+ SmallVectorImpl<CallBase *> &ExtPoints) {
+ if (auto Iter = TokenDefs.find(BB); Iter != TokenDefs.end()) {
+ for (CallBase *Def : Iter->second) {
+ for (CallBase *TokenUse : TokenDefUses[Def]) {
+ BasicBlock *BB = TokenUse->getParent();
+ if (splitForExtension(CI, DefCycle, BB, TokenUse, TokenDefs)) {
+ ExtPoints.push_back(TokenUse);
+ }
+ }
+ }
+ }
+}
+
+static void initialize(ExtensionMap &ExtBorder, TokenDefsMap &TokenDefs,
+ TokenDefUsesMap &TokenDefUses, Function &F,
+ CycleInfo &CI) {
+ for (BasicBlock &BB : F) {
+ updateTokenDefs(TokenDefs, BB);
+ for (Instruction &I : BB) {
+ if (auto *CB = dyn_cast<CallBase>(&I)) {
+ if (auto *TokenDef =
+ cast_or_null<CallBase>(CB->getConvergenceControlToken())) {
+ TokenDefUses[TokenDef].push_back(CB);
+ }
+ }
+ }
+ }
+
+ for (BasicBlock &BB : F) {
+ if (Cycle *DefCycle = CI.getCycle(&BB)) {
+ SmallVector<CallBase *> ExtPoints;
+ locateExtensions(CI, DefCycle, &BB, TokenDefs, TokenDefUses, ExtPoints);
+ if (!ExtPoints.empty()) {
+ auto Success = ExtBorder.try_emplace(DefCycle, std::move(ExtPoints));
+ (void)Success;
+ assert(Success.second);
+ }
+ }
+ }
+}
+
+static bool hasSuccInsideCycle(BasicBlock *BB, Cycle *C) {
+ for (BasicBlock *Succ : successors(BB)) {
+ if (C->contains(Succ))
+ return true;
+ }
+ return false;
+}
+
+PreservedAnalyses CycleConvergenceExtendPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ LLVM_DEBUG(dbgs() << "===== Cycle convergence extension for function "
+ << F.getName() << "\n");
+
+ DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+ CycleInfo &CI = AM.getResult<CycleAnalysis>(F);
+ ExtensionMap ExtBorder;
+ TokenDefsMap TokenDefs;
+ TokenDefUsesMap TokenDefUses;
+
+ initialize(ExtBorder, TokenDefs, TokenDefUses, F, CI);
+ if (ExtBorder.empty())
+ return PreservedAnalyses::all();
+
+ for (auto Iter : ExtBorder) {
+ Cycle *DefCycle = Iter.first;
+ auto &ExtList = Iter.second;
+ SmallVector<BasicBlock *> TransferredBlocks;
+
+ LLVM_DEBUG(dbgs() << "Extend cycle with header "
+ << DefCycle->getHeader()->getName());
+ assert(!ExtList.empty());
+ // Iterate by index because ExtList may grow during this loop, which would
+ // invalidate SmallVector iterators.
+ for (unsigned Idx = 0; Idx != ExtList.size(); ++Idx) {
+ CallBase *ExtPoint = ExtList[Idx];
+ if (DefCycle->contains(ExtPoint->getParent()))
+ continue;
+ LLVM_DEBUG(dbgs() << "\n up to " << ExtPoint->getParent()->getName()
+ << "\n for token used: " << *ExtPoint << "\n");
+ CI.extendCycle(DefCycle, ExtPoint->getParent(), &TransferredBlocks);
+ for (BasicBlock *BB : TransferredBlocks) {
+ locateExtensions(CI, DefCycle, BB, TokenDefs, TokenDefUses, ExtList);
+ }
+ };
+
+ LLVM_DEBUG(dbgs() << "After extension:\n" << CI.print(DefCycle) << "\n");
+
+ BBSetVector Incoming, Outgoing;
+ SmallVector<BasicBlock *> GuardBlocks;
+
+ for (CallBase *ExtPoint : ExtList) {
+ BasicBlock *BB = ExtPoint->getParent();
+ if (!hasSuccInsideCycle(BB, DefCycle))
+ Incoming.insert(BB);
+ }
+ for (BasicBlock *BB : Incoming) {
+ for (BasicBlock *Succ : successors(BB)) {
+ if (!DefCycle->contains(Succ))
+ Outgoing.insert(Succ);
+ }
+ }
+
+ // Redirect the backedges as well, just to add non-trivial edges to the ones
+ // being redirected.
+ for (BasicBlock *Pred : predecessors(DefCycle->getHeader())) {
+ if (DefCycle->contains(Pred))
+ Incoming.insert(Pred);
+ }
+ // We don't touch the exiting edges of the latches simply because
+ // redirecting them is not a post-condition of this transform. Separately,
+ // the header must be the last Outgoing block so that the entire chain of
+ // guard blocks is included in the cycle.
+ Outgoing.insert(DefCycle->getHeader());
+
+ CreateControlFlowHub(&DTU, GuardBlocks, Incoming, Outgoing, "Extend");
+ for (BasicBlock *BB : GuardBlocks)
+ CI.addBlockToCycle(BB, DefCycle);
+ }
+
+ PreservedAnalyses PA;
+ return PA;
+}
diff --git a/llvm/test/Transforms/CycleConvergenceExtend/basic.ll b/llvm/test/Transforms/CycleConvergenceExtend/basic.ll
new file mode 100644
index 0000000000000..591e55eaa477f
--- /dev/null
+++ b/llvm/test/Transforms/CycleConvergenceExtend/basic.ll
@@ -0,0 +1,405 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -passes='cycle-convergence-extend' 2>&1 | FileCheck %s
+
+;
+; |
+; A] %a1 = anchor
+; |
+; B
+; |\
+; | C
+; |/ \
+; D |
+; E %e = user (%a1)
+;
+
+define void @extend_loops(i1 %flag1, i1 %flag2, i1 %flag3) {
+; CHECK-LABEL: define void @extend_loops(
+; CHECK-SAME: i1 [[FLAG1:%.*]], i1 [[FLAG2:%.*]], i1 [[FLAG3:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[EXTEND_GUARD:.*]], label %[[B:.*]]
+; CHECK: [[B]]:
+; CHECK-NEXT: br i1 [[FLAG2]], label %[[C:.*]], label %[[D:.*]]
+; CHECK: [[C]]:
+; CHECK-NEXT: br i1 [[FLAG3]], label %[[D]], label %[[E_EXT:.*]]
+; CHECK: [[D]]:
+; CHECK-NEXT: ret void
+; CHECK: [[E_EXT]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[E:.*]]:
+; CHECK-NEXT: br label %[[F:.*]]
+; CHECK: [[F]]:
+; CHECK-NEXT: ret void
+; CHECK: [[EXTEND_GUARD]]:
+; CHECK-NEXT: [[GUARD_E:%.*]] = phi i1 [ true, %[[E_EXT]] ], [ false, %[[A1]] ]
+; CHECK-NEXT: br i1 [[GUARD_E]], label %[[E]], label %[[A1]]
+;
+entry:
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.anchor()
+ br i1 %flag1, label %A, label %B
+
+B:
+ br i1 %flag2, label %C, label %D
+
+C:
+ br i1 %flag3, label %D, label %E
+
+D:
+ ret void
+
+E:
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %a1) ]
+ br label %F
+
+F:
+ ret void
+}
+
+;
+; |
+; A] %a1 = anchor
+; |
+; B %b1 = anchor
+; / \
+; C \
+; / \ |
+; D | |
+; E | %e = user (%b1)
+; |
+; F %f = user (%a1)
+;
+
+define void @extend_loops_iterate(i1 %flag1, i1 %flag2, i1 %flag3) {
+; CHECK-LABEL: define void @extend_loops_iterate(
+; CHECK-SAME: i1 [[FLAG1:%.*]], i1 [[FLAG2:%.*]], i1 [[FLAG3:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[EXTEND_GUARD:.*]], label %[[B1:.*]]
+; CHECK: [[B1]]:
+; CHECK-NEXT: [[B:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br i1 [[FLAG2]], label %[[C:.*]], label %[[F_EXT:.*]]
+; CHECK: [[C]]:
+; CHECK-NEXT: br i1 [[FLAG3]], label %[[D:.*]], label %[[E_EXT:.*]]
+; CHECK: [[D]]:
+; CHECK-NEXT: ret void
+; CHECK: [[E_EXT]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[E:.*]]:
+; CHECK-NEXT: ret void
+; CHECK: [[F_EXT]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[F:.*]]:
+; CHECK-NEXT: ret void
+; CHECK: [[EXTEND_GUARD]]:
+; CHECK-NEXT: [[GUARD_F:%.*]] = phi i1 [ true, %[[F_EXT]] ], [ false, %[[E_EXT]] ], [ false, %[[A1]] ]
+; CHECK-NEXT: [[GUARD_E:%.*]] = phi i1 [ false, %[[F_EXT]] ], [ true, %[[E_EXT]] ], [ false, %[[A1]] ]
+; CHECK-NEXT: br i1 [[GUARD_F]], label %[[F]], label %[[EXTEND_GUARD1:.*]]
+; CHECK: [[EXTEND_GUARD1]]:
+; CHECK-NEXT: br i1 [[GUARD_E]], label %[[E]], label %[[A1]]
+;
+entry:
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.anchor()
+ br i1 %flag1, label %A, label %B
+
+B:
+ %b1 = call token @llvm.experimental.convergence.anchor()
+ br i1 %flag2, label %C, label %F
+
+C:
+ br i1 %flag3, label %D, label %E
+
+D:
+ ret void
+
+E:
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ ret void
+
+F:
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %a1) ]
+ ret void
+}
+
+;
+; |
+; A<-\ %a1 = heart
+; | |
+; B] | %b1 = heart (%a1)
+; | |
+; C>-/
+; |
+; D %d1 = user (%b1)
+; %d2 = user (%a1)
+;
+
+define void @nested_loop_extension(i1 %flag1, i1 %flag2) {
+; CHECK-LABEL: define void @nested_loop_extension(
+; CHECK-SAME: i1 [[FLAG1:%.*]], i1 [[FLAG2:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[ANCHOR:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[ANCHOR]]) ]
+; CHECK-NEXT: br label %[[B1:.*]]
+; CHECK: [[B1]]:
+; CHECK-NEXT: [[B:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[EXTEND_GUARD:.*]], label %[[C:.*]]
+; CHECK: [[C]]:
+; CHECK-NEXT: br i1 [[FLAG2]], label %[[EXTEND_GUARD1:.*]], label %[[D_EXT_EXT:.*]]
+; CHECK: [[D_EXT_EXT]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[D_EXT:.*]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD1]]
+; CHECK: [[D:.*]]:
+; CHECK-NEXT: ret void
+; CHECK: [[EXTEND_GUARD]]:
+; CHECK-NEXT: [[GUARD_D_EXT:%.*]] = phi i1 [ true, %[[D_EXT_EXT]] ], [ false, %[[B1]] ]
+; CHECK-NEXT: br i1 [[GUARD_D_EXT]], label %[[D_EXT]], label %[[B1]]
+; CHECK: [[EXTEND_GUARD1]]:
+; CHECK-NEXT: [[GUARD_D:%.*]] = phi i1 [ true, %[[D_EXT]] ], [ false, %[[C]] ]
+; CHECK-NEXT: br i1 [[GUARD_D]], label %[[D]], label %[[A1]]
+;
+entry:
+ %anchor = call token @llvm.experimental.convergence.anchor()
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %anchor) ]
+ br label %B
+
+B:
+ %b1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %a1) ]
+ br i1 %flag1, label %B, label %C
+
+C:
+ br i1 %flag2, label %A,label %D
+
+D:
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %a1) ]
+ ret void
+}
+
+;
+; |
+; A] %a1 = anchor
+; |
+; B %b1 = anchor <-- should be associated to extended cycle!
+; |\ %b2 = user (%b1)
+; | X
+; C %c = user (%a1)
+;
+
+define void @multi_block_extension(i1 %flag1, i1 %flag2) {
+; CHECK-LABEL: define void @multi_block_extension(
+; CHECK-SAME: i1 [[FLAG1:%.*]], i1 [[FLAG2:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[EXTEND_GUARD:.*]], label %[[B:.*]]
+; CHECK: [[B]]:
+; CHECK-NEXT: [[B1:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B1]]) ]
+; CHECK-NEXT: br i1 [[FLAG2]], label %[[X:.*]], label %[[C_EXT:.*]]
+; CHECK: [[X]]:
+; CHECK-NEXT: ret void
+; CHECK: [[C_EXT]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[C:.*]]:
+; CHECK-NEXT: ret void
+; CHECK: [[EXTEND_GUARD]]:
+; CHECK-NEXT: [[GUARD_C:%.*]] = phi i1 [ true, %[[C_EXT]] ], [ false, %[[A1]] ]
+; CHECK-NEXT: br i1 [[GUARD_C]], label %[[C]], label %[[A1]]
+;
+entry:
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.anchor()
+ br i1 %flag1, label %A, label %B
+
+B:
+ %b1 = call token @llvm.experimental.convergence.anchor()
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ br i1 %flag2, label %X, label %C
+
+X:
+ ret void
+
+C:
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %a1) ]
+ ret void
+}
+
+;
+; |
+; A] %a1 = anchor
+; |
+; B %b1 = anchor <-- should be associated to extended cycle!
+; %b2 = user (%b1)
+; %b3 = user (%a1)
+;
+
+define void @multi_extension(i1 %flag1) {
+; CHECK-LABEL: define void @multi_extension(
+; CHECK-SAME: i1 [[FLAG1:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[EXTEND_GUARD:.*]], label %[[B_EXT:.*]]
+; CHECK: [[B_EXT]]:
+; CHECK-NEXT: [[B1:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B1]]) ]
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[B:.*]]:
+; CHECK-NEXT: ret void
+; CHECK: [[EXTEND_GUARD]]:
+; CHECK-NEXT: [[GUARD_B:%.*]] = phi i1 [ true, %[[B_EXT]] ], [ false, %[[A1]] ]
+; CHECK-NEXT: br i1 [[GUARD_B]], label %[[B]], label %[[A1]]
+;
+entry:
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.anchor()
+ br i1 %flag1, label %A, label %B
+
+B:
+ %b1 = call token @llvm.experimental.convergence.anchor()
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %a1) ]
+ ret void
+}
+
+;
+; |
+; A] %a1 = anchor
+; |
+; B] %b1 = loop heart (%a1)
+; | %b2 = user (%b1)
+; |
+; C
+;
+
+define void @lift_loop(i1 %flag1, i1 %flag2) {
+; CHECK-LABEL: define void @lift_loop(
+; CHECK-SAME: i1 [[FLAG1:%.*]], i1 [[FLAG2:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[A1]], label %[[B_EXT:.*]]
+; CHECK: [[B_EXT]]:
+; CHECK-NEXT: [[B1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[B:.*]]
+; CHECK: [[B]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B1]]) ]
+; CHECK-NEXT: br i1 [[FLAG2]], label %[[B_EXT]], label %[[C:.*]]
+; CHECK: [[C]]:
+; CHECK-NEXT: ret void
+;
+entry:
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.anchor()
+ br i1 %flag1, label %A, label %B
+
+B:
+ %b1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %a1) ]
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ br i1 %flag2, label %B, label %C
+
+C:
+ ret void
+}
+
+define void @false_heart_trivial() convergent {
+; CHECK-LABEL: define void @false_heart_trivial(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT: br label %[[NEXT:.*]]
+; CHECK: [[NEXT]]:
+; CHECK-NEXT: [[B:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B]]) ]
+; CHECK-NEXT: ret void
+;
+entry:
+ %a1 = call token @llvm.experimental.convergence.entry()
+ br label %next
+
+next:
+ %b1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %a1) ]
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ ret void
+}
+
+;
+; |
+; A] %a1 = loop heart
+; |
+; B %b1 = false heart (%a1)
+; %b2 = user (%b1)
+;
+
+define void @false_heart_lifted(i1 %flag1) {
+; CHECK-LABEL: define void @false_heart_lifted(
+; CHECK-SAME: i1 [[FLAG1:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[ANCHOR:%.*]] = call token @llvm.experimental.convergence.anchor()
+; CHECK-NEXT: br label %[[A1:.*]]
+; CHECK: [[A1]]:
+; CHECK-NEXT: [[A:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[ANCHOR]]) ]
+; CHECK-NEXT: br i1 [[FLAG1]], label %[[EXTEND_GUARD:.*]], label %[[B_EXT:.*]]
+; CHECK: [[B_EXT]]:
+; CHECK-NEXT: [[B1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[A]]) ]
+; CHECK-NEXT: br label %[[B_EXT1:.*]]
+; CHECK: [[B_EXT1]]:
+; CHECK-NEXT: call void @convergent.op(i32 0) [ "convergencectrl"(token [[B1]]) ]
+; CHECK-NEXT: br label %[[EXTEND_GUARD]]
+; CHECK: [[B:.*]]:
+; CHECK-NEXT: ret void
+; CHECK: [[EXTEND_GUARD]]:
+; CHECK-NEXT: [[GUARD_B:%.*]] = phi i1 [ true, %[[B_EXT1]] ], [ false, %[[A1]] ]
+; CHECK-NEXT: br i1 [[GUARD_B]], label %[[B]], label %[[A1]]
+;
+entry:
+ %anchor = call token @llvm.experimental.convergence.anchor()
+ br label %A
+
+A:
+ %a1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %anchor) ]
+ br i1 %flag1, label %A, label %B
+
+B:
+ %b1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %a1) ]
+ call void @convergent.op(i32 0) [ "convergencectrl"(token %b1) ]
+ ret void
+}
+
+declare void @convergent.op(i32) convergent
+
+declare token @llvm.experimental.convergence.entry()
+declare token @llvm.experimental.convergence.anchor()
+declare token @llvm.experimental.convergence.loop()
From a454225e8a6e4eb1a2f82138cfb14c53bd65364f Mon Sep 17 00:00:00 2001
From: Sameer Sahasrabuddhe <sameer.sahasrabuddhe at amd.com>
Date: Tue, 9 Jul 2024 13:47:53 +0530
Subject: [PATCH 2/2] Incremental review change:
- Actually update the DomTree.
- Declare that DomTree and CycleInfo are preserved.
- Additional explanatory note at the top of the file.
---
.../Scalar/CycleConvergenceExtend.cpp | 43 ++++++++++++++++---
1 file changed, 37 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp b/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
index db8e3942ae68b..a0546c86a2c12 100644
--- a/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
+++ b/llvm/lib/Transforms/Scalar/CycleConvergenceExtend.cpp
@@ -6,6 +6,19 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+// NOTE: It is not clear whether the effects of this transform can survive
+// other control-flow transforms such as jump threading. Whether every such
+// transform can preserve the extended CFG, and even if it can, whether it
+// should, has not been determined yet.
+//
+// For now, this transform is meant to be used as late as possible, when
+// preparing the CFG for code generation on targets that support convergence
+// control tokens, such as AMDGPU. It is possible that the transform may
+// eventually be merged into the structurizer or similar passes.
+//
+// But notably, this transform serves as a good WYSIWYM demonstration of
+// convergence control tokens.
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to extend cycles: if a token T defined in a cycle
// L is used at U outside of L, then the entire cycle nest is modified so that
@@ -110,11 +123,13 @@ static void updateTokenDefs(TokenDefsMap &TokenDefs, BasicBlock &BB) {
}
static bool splitForExtension(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
- CallBase *TokenUse, TokenDefsMap &TokenDefs) {
+ CallBase *TokenUse, TokenDefsMap &TokenDefs,
+ DomTreeUpdater &DTU) {
if (DefCycle->contains(BB))
return false;
BasicBlock *NewBB = BB->splitBasicBlockBefore(TokenUse->getNextNode(),
BB->getName() + ".ext");
+ DTU.getDomTree().splitBlock(NewBB);
if (Cycle *BBCycle = CI.getCycle(BB))
CI.addBlockToCycle(NewBB, BBCycle);
updateTokenDefs(TokenDefs, *BB);
@@ -125,12 +140,13 @@ static bool splitForExtension(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
static void locateExtensions(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
TokenDefsMap &TokenDefs,
TokenDefUsesMap &TokenDefUses,
+ DomTreeUpdater &DTU,
SmallVectorImpl<CallBase *> &ExtPoints) {
if (auto Iter = TokenDefs.find(BB); Iter != TokenDefs.end()) {
for (CallBase *Def : Iter->second) {
for (CallBase *TokenUse : TokenDefUses[Def]) {
BasicBlock *BB = TokenUse->getParent();
- if (splitForExtension(CI, DefCycle, BB, TokenUse, TokenDefs)) {
+ if (splitForExtension(CI, DefCycle, BB, TokenUse, TokenDefs, DTU)) {
ExtPoints.push_back(TokenUse);
}
}
@@ -140,7 +156,7 @@ static void locateExtensions(CycleInfo &CI, Cycle *DefCycle, BasicBlock *BB,
static void initialize(ExtensionMap &ExtBorder, TokenDefsMap &TokenDefs,
TokenDefUsesMap &TokenDefUses, Function &F,
- CycleInfo &CI) {
+ CycleInfo &CI, DomTreeUpdater &DTU) {
for (BasicBlock &BB : F) {
updateTokenDefs(TokenDefs, BB);
for (Instruction &I : BB) {
@@ -156,7 +172,7 @@ static void initialize(ExtensionMap &ExtBorder, TokenDefsMap &TokenDefs,
for (BasicBlock &BB : F) {
if (Cycle *DefCycle = CI.getCycle(&BB)) {
SmallVector<CallBase *> ExtPoints;
- locateExtensions(CI, DefCycle, &BB, TokenDefs, TokenDefUses, ExtPoints);
+ locateExtensions(CI, DefCycle, &BB, TokenDefs, TokenDefUses, DTU, ExtPoints);
if (!ExtPoints.empty()) {
auto Success = ExtBorder.try_emplace(DefCycle, std::move(ExtPoints));
(void)Success;
@@ -186,7 +202,7 @@ PreservedAnalyses CycleConvergenceExtendPass::run(Function &F,
TokenDefsMap TokenDefs;
TokenDefUsesMap TokenDefUses;
- initialize(ExtBorder, TokenDefs, TokenDefUses, F, CI);
+ initialize(ExtBorder, TokenDefs, TokenDefUses, F, CI, DTU);
if (ExtBorder.empty())
return PreservedAnalyses::all();
@@ -206,12 +222,15 @@ PreservedAnalyses CycleConvergenceExtendPass::run(Function &F,
<< "\n for token used: " << *ExtPoint << "\n");
CI.extendCycle(DefCycle, ExtPoint->getParent(), &TransferredBlocks);
for (BasicBlock *BB : TransferredBlocks) {
- locateExtensions(CI, DefCycle, BB, TokenDefs, TokenDefUses, ExtList);
+ locateExtensions(CI, DefCycle, BB, TokenDefs, TokenDefUses, DTU, ExtList);
}
};
LLVM_DEBUG(dbgs() << "After extension:\n" << CI.print(DefCycle) << "\n");
+ // Now that we have absorbed the convergence extensions into the cycle, we
+ // need to introduce dummy backedges so that the cycle remains strongly
+ // connected.
BBSetVector Incoming, Outgoing;
SmallVector<BasicBlock *> GuardBlocks;
@@ -242,8 +261,20 @@ PreservedAnalyses CycleConvergenceExtendPass::run(Function &F,
CreateControlFlowHub(&DTU, GuardBlocks, Incoming, Outgoing, "Extend");
for (BasicBlock *BB : GuardBlocks)
CI.addBlockToCycle(BB, DefCycle);
+ DTU.flush();
}
+#if !defined(NDEBUG)
+#if defined(EXPENSIVE_CHECKS)
+ assert(DT.verify(DominatorTree::VerificationLevel::Full));
+#else
+ assert(DT.verify(DominatorTree::VerificationLevel::Fast));
+#endif // EXPENSIVE_CHECKS
+ assert(CI.validateTree());
+#endif // NDEBUG
+
PreservedAnalyses PA;
+ PA.preserve<DominatorTreeAnalysis>();
+ PA.preserve<CycleAnalysis>();
return PA;
}
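For reference, the new pass can be exercised on its own through the pipeline
name registered above, mirroring the RUN line of the added test (input.ll is a
placeholder file name):
  opt -S -passes='cycle-convergence-extend' input.ll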