[llvm] [SPIR-V] add convergence region analysis (PR #78456)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 22 08:37:25 PST 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/78456

>From 6542312a28d9aed6e4883608aab09df3afaef838 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 2 Jan 2024 12:28:51 +0100
Subject: [PATCH 1/4] [SPIR-V] add convergence region analysis

This new analysis returns a hierarchical view of the convergence
regions in the given function.
This will allow our passes to query which basic block belongs to
which convergence region, and to structurize the code accordingly.

Definition
----------

A convergence region is a CFG with:
 - a single entry node.
 - one or multiple exit nodes (different from LLVM's regions).
 - one back-edge
 - zero or more subregions.

Excluding the nodes of its subregions, the nodes of a region can only
reference a single convergence token. A subregion uses a different
convergence token.

Algorithm
---------

This algorithm assumes all loops are in LoopSimplify form.

Create an initial convergence region for the whole function.
  - the convergence token is the function entry token.
  - the entry is the function entrypoint.
  - Exits are all the basic blocks terminating with a return instruction.

Take the function CFG, and process it in DAG order (ignoring back-edges).
If a basic block is a loop header:
 - Create a new region.
   - The parent region is the parent loop's region if any; otherwise, the
     top-level region.
   - The region blocks are all the blocks belonging to this loop.
   - For each loop exit:
        - visit the rest of the CFG in DAG order (ignore back-edges).
        - if the region's convergence token is found, add all the blocks
          dominated by the exit from which the token is reachable to
          the region.
   - continue the algorithm with the loop header's successors.
---
 llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt |   10 +
 .../Analysis/ConvergenceRegionAnalysis.cpp    |  310 +++++
 .../Analysis/ConvergenceRegionAnalysis.h      |  173 +++
 llvm/lib/Target/SPIRV/CMakeLists.txt          |    2 +
 llvm/unittests/Target/SPIRV/CMakeLists.txt    |   17 +
 .../SPIRV/ConvergenceRegionAnalysisTests.cpp  | 1015 +++++++++++++++++
 6 files changed, 1527 insertions(+)
 create mode 100644 llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
 create mode 100644 llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp
 create mode 100644 llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h
 create mode 100644 llvm/unittests/Target/SPIRV/CMakeLists.txt
 create mode 100644 llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp

diff --git a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
new file mode 100644
index 00000000000000..374aee3ed1c766
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_component_library(LLVMSPIRVAnalysis
+  ConvergenceRegionAnalysis.cpp
+
+  LINK_COMPONENTS
+  Core
+  Support
+
+  ADD_TO_COMPONENT
+  SPIRV
+  )
diff --git a/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp
new file mode 100644
index 00000000000000..5102bc2d4228cc
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp
@@ -0,0 +1,310 @@
+//===- ConvergenceRegionAnalysis.cpp ---------------------------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The analysis determines the convergence region for each basic block of
+// the module, and provides a tree-like structure describing the region
+// hierarchy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ConvergenceRegionAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include <optional>
+#include <queue>
+
+namespace llvm {
+namespace SPIRV {
+
+namespace {
+
+// Returns the convergence token of the first instruction of |BB| that defines
+// or consumes one: a convergence intrinsic, or a non-intrinsic call carrying a
+// "convergencectrl" bundle. Returns |nullopt| if there is no such instruction.
+template <typename BasicBlockType, typename IntrinsicInstType>
+std::optional<IntrinsicInstType *>
+getConvergenceTokenInternal(BasicBlockType *BB) {
+  static_assert(std::is_const_v<IntrinsicInstType> ==
+                    std::is_const_v<BasicBlockType>,
+                "Constness must match between input and output.");
+  static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
+                "Input must be a basic block.");
+  static_assert(
+      std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
+      "Output type must be an intrinsic instruction.");
+
+  for (auto &I : *BB) {
+    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::experimental_convergence_entry:
+      case Intrinsic::experimental_convergence_loop:
+        return II;
+      case Intrinsic::experimental_convergence_anchor: {
+        // NOTE(review): Bundle is dereferenced unchecked; confirm it exists.
+        auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);
+        assert(Bundle->Inputs.size() == 1 &&
+               Bundle->Inputs[0]->getType()->isTokenTy());
+        auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());
+        assert(TII != nullptr);
+        return TII;
+      }
+      default:
+        continue;
+      }
+    }
+
+    if (auto *CI = dyn_cast<CallInst>(&I)) {
+      auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
+      if (OB.has_value())
+        return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
+    }
+  }
+
+  return std::nullopt;
+}
+
+} // anonymous namespace
+
+std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
+  return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
+}
+// Const overload: same lookup, instantiated with const-qualified types.
+std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {
+  return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
+}
+
+// Top-level region: spans every block of |F|; exits are the returning blocks.
+ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
+                                     Function &F)
+    : DT(DT), LI(LI), Parent(nullptr), Entry(&F.getEntryBlock()) {
+  ConvergenceToken = getConvergenceToken(Entry);
+  for (BasicBlock &Block : F) {
+    Blocks.insert(&Block);
+    if (isa<ReturnInst>(Block.getTerminator()))
+      Exits.insert(&Block);
+  }
+}
+
+ConvergenceRegion::ConvergenceRegion(
+    DominatorTree &DT, LoopInfo &LI,
+    std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
+    SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
+    : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
+      Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
+  for ([[maybe_unused]] auto *BB : this->Exits)
+    assert(this->Blocks.count(BB) != 0 && "Exit block outside of the region.");
+  assert(this->Blocks.count(this->Entry) != 0 && "Entry outside of region.");
+}
+
+// Drops the (non-owned) parent link, then destroys every owned sub-region.
+void ConvergenceRegion::releaseMemory() {
+  Parent = nullptr;
+  for (ConvergenceRegion *SubRegion : Children) {
+    SubRegion->releaseMemory();
+    delete SubRegion;
+  }
+  Children.clear();
+}
+
+// Writes this region and, recursively, its sub-regions to the debug stream.
+void ConvergenceRegion::dump(const unsigned IndentSize) const {
+  const std::string Indent(IndentSize, '\t');
+  // A block is printed by name, or by address when it has no name.
+  const auto PrintBlock = [](const BasicBlock *BB) {
+    if (BB->getName() != "")
+      dbgs() << BB->getName();
+    else
+      dbgs() << BB;
+  };
+  // Prints |Label| followed by the comma-separated blocks of |Set|.
+  const auto PrintBlockSet = [&](const char *Label, const auto &Set) {
+    dbgs() << Indent << Label;
+    for (const BasicBlock *BB : Set) {
+      PrintBlock(BB);
+      dbgs() << ", ";
+    }
+    dbgs() << "\t}\n";
+  };
+
+  dbgs() << Indent << this << ": {\n";
+  dbgs() << Indent << "\tParent: " << Parent << "\n";
+  if (ConvergenceToken.value_or(nullptr)) {
+    dbgs() << Indent
+           << "\tConvergenceToken: " << ConvergenceToken.value()->getName()
+           << "\n";
+  }
+
+  dbgs() << Indent << "\tEntry: ";
+  PrintBlock(Entry);
+  dbgs() << "\n";
+  PrintBlockSet("\tExits: { ", Exits);
+  PrintBlockSet("\tBlocks: { ", Blocks);
+
+  dbgs() << Indent << "\tChildren: {\n";
+  for (const ConvergenceRegion *Child : Children)
+    Child->dump(IndentSize + 2);
+  dbgs() << Indent << "\t}\n";
+
+  dbgs() << Indent << "}\n";
+}
+
+// Builds the convergence region hierarchy of a function from its loop nest.
+class ConvergenceRegionAnalyzer {
+public:
+  ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
+      : DT(DT), LI(LI), F(F) {}
+
+private:
+  // Returns true if the edge |From| -> |To| is a loop back-edge.
+  bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
+    // We only handle loops in the simplified form. This means:
+    // - a single back-edge, a single latch.
+    // - meaning the back-edge target can only be the loop header.
+    // - meaning the From can only be the loop latch.
+    // |From| == |To| is accepted: a single-block loop is its own latch.
+    if (!LI.isLoopHeader(To))
+      return false;
+
+    const Loop *L = LI.getLoopFor(To);
+    return L->contains(From) && L->isLoopLatch(From);
+  }
+
+  // Returns every block lying on a forward path (back-edges are ignored) that
+  // starts at |From| and ends on a block accepted by |isMatch|, both ends
+  // included. The result is empty when no accepted block is reachable.
+  std::unordered_set<BasicBlock *> findPathsToMatch(
+      BasicBlock *From,
+      const std::function<bool(const BasicBlock *)> &isMatch) const {
+    std::unordered_set<BasicBlock *> Output;
+    if (isMatch(From))
+      Output.insert(From);
+
+    auto *Terminator = From->getTerminator();
+    for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
+      auto *To = Terminator->getSuccessor(i);
+      if (isBackEdge(From, To))
+        continue;
+
+      auto ChildSet = findPathsToMatch(To, isMatch);
+      if (ChildSet.empty())
+        continue;
+
+      Output.insert(ChildSet.begin(), ChildSet.end());
+      Output.insert(From);
+    }
+
+    return Output;
+  }
+
+  // Returns the blocks of |RegionBlocks| with a successor outside of the set.
+  SmallPtrSet<BasicBlock *, 2>
+  findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
+    SmallPtrSet<BasicBlock *, 2> Exits;
+    for (auto *B : RegionBlocks) {
+      auto *Terminator = B->getTerminator();
+      for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
+        auto *Child = Terminator->getSuccessor(i);
+        if (RegionBlocks.count(Child) == 0)
+          Exits.insert(B);
+      }
+    }
+    return Exits;
+  }
+
+public:
+  // Creates the top-level region, then one nested region per loop.
+  ConvergenceRegionInfo analyze() {
+    ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
+
+    std::unordered_map<Loop *, ConvergenceRegion *> LoopToRegion;
+    // Breadth-first walk: a parent loop gets its region before its sub-loops.
+    std::queue<Loop *> ToProcess;
+    for (auto *L : LI)
+      ToProcess.push(L);
+
+    while (ToProcess.size() != 0) {
+      auto *L = ToProcess.front();
+      ToProcess.pop();
+      for (auto *Child : *L)
+        ToProcess.push(Child);
+
+      assert(L->isLoopSimplifyForm());
+
+      auto CT = getConvergenceToken(L->getHeader());
+      SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
+                                                L->block_end());
+      SmallVector<BasicBlock *> LoopExits;
+      L->getExitingBlocks(LoopExits);
+      if (CT.has_value()) {
+        for (auto *Exit : LoopExits) {
+          auto N = findPathsToMatch(Exit, [&CT](const BasicBlock *block) {
+            auto Token = getConvergenceToken(block);
+            if (Token == std::nullopt)
+              return false;
+            return Token.value() == CT.value();
+          });
+          RegionBlocks.insert(N.begin(), N.end());
+        }
+      }
+
+      auto RegionExits = findExitNodes(RegionBlocks);
+      ConvergenceRegion *Region = new ConvergenceRegion(
+          DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
+          std::move(RegionExits));
+
+      auto It = LoopToRegion.find(L->getParentLoop());
+      assert(It != LoopToRegion.end() || L->getParentLoop() == nullptr);
+      Region->Parent = It != LoopToRegion.end() ? It->second : TopLevelRegion;
+      Region->Parent->Children.push_back(Region);
+
+      LoopToRegion.emplace(L, Region);
+    }
+
+    return ConvergenceRegionInfo(TopLevelRegion);
+  }
+
+private:
+  DominatorTree &DT;
+  LoopInfo &LI;
+  Function &F;
+};
+
+// Computes the convergence region hierarchy of |F| with a one-shot analyzer.
+ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
+                                            LoopInfo &LI) {
+  return ConvergenceRegionAnalyzer(F, DT, LI).analyze();
+}
+
+char ConvergenceRegionAnalysisWrapperPass::ID = 0;
+// Legacy pass constructor: forwards the pass |ID| to the FunctionPass base.
+ConvergenceRegionAnalysisWrapperPass::ConvergenceRegionAnalysisWrapperPass()
+    : FunctionPass(ID) {}
+
+// Recomputes |CRI| for |F| from the dominator tree and loop info analyses.
+bool ConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
+  auto &DomTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  CRI = getConvergenceRegions(F, DomTree, Loops);
+  // This is an analysis: the IR is left untouched.
+  return false;
+}
+
+// New pass manager entry point: fetches the dominator tree and the loop info
+// of |F| from |AM| and builds the region hierarchy from them.
+ConvergenceRegionAnalysis::Result
+ConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+  DominatorTree &DomTree = AM.getResult<DominatorTreeAnalysis>(F);
+  LoopInfo &Loops = AM.getResult<LoopAnalysis>(F);
+  return getConvergenceRegions(F, DomTree, Loops);
+}
+
+AnalysisKey ConvergenceRegionAnalysis::Key;
+
+} // namespace SPIRV
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h b/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h
new file mode 100644
index 00000000000000..c8cd1c4cd9ddf7
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h
@@ -0,0 +1,173 @@
+//===- ConvergenceRegionAnalysis.h -----------------------------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The analysis determines the convergence region for each basic block of
+// the module, and provides a tree-like structure describing the region
+// hierarchy.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include <iostream>
+#include <optional>
+#include <unordered_set>
+
+namespace llvm {
+class SPIRVSubtarget;
+class MachineFunction;
+class MachineModuleInfo;
+
+namespace SPIRV {
+
+// Returns the first convergence token defined or used in |BB|, else |nullopt|.
+std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB);
+std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB);
+
+// Describes a hierarchy of convergence regions.
+// A convergence region defines a CFG for which the execution flow can diverge
+// starting from the entry block, but should reconverge back before the end of
+// the exit blocks.
+class ConvergenceRegion {
+  DominatorTree &DT;
+  LoopInfo &LI;
+
+public:
+  // The parent region of this region, if any. Not owned by this region.
+  ConvergenceRegion *Parent = nullptr;
+  // The sub-regions contained in this region, if any. Owned by this region.
+  SmallVector<ConvergenceRegion *> Children = {};
+  // The convergence instruction linked to this region, if any.
+  std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt;
+  // The only block with a predecessor outside of this region.
+  BasicBlock *Entry = nullptr;
+  // All the blocks with an edge leaving this convergence region.
+  SmallPtrSet<BasicBlock *, 2> Exits = {};
+  // All the blocks that belong to this region, including its subregions'.
+  SmallPtrSet<BasicBlock *, 8> Blocks = {};
+
+  // Creates a single convergence region encapsulating the whole function |F|.
+  ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F);
+
+  // Creates a single convergence region defined by entry and exit nodes, a
+  // list of blocks, and possibly a convergence token.
+  ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
+                    std::optional<IntrinsicInst *> ConvergenceToken,
+                    BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks,
+                    SmallPtrSet<BasicBlock *, 2> &&Exits);
+  // NOTE(review): moved children still reference |CR| as their Parent.
+  ConvergenceRegion(ConvergenceRegion &&CR)
+      : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)),
+        Children(std::move(CR.Children)),
+        ConvergenceToken(std::move(CR.ConvergenceToken)),
+        Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)),
+        Blocks(std::move(CR.Blocks)) {}
+
+  ConvergenceRegion(const ConvergenceRegion &other) = delete;
+
+  // Returns true if the given basic block belongs to this region, or to one of
+  // its subregions.
+  bool contains(const BasicBlock *BB) const { return Blocks.count(BB) != 0; }
+  // Deletes all the sub-regions, recursively, and resets |Parent|.
+  void releaseMemory();
+
+  // Writes this region's hierarchy to the debug output.
+  // |IndentSize| defines the number of tabs to print before any new line.
+  void dump(const unsigned IndentSize = 0) const;
+};
+
+// Holds a ConvergenceRegion hierarchy.
+class ConvergenceRegionInfo {
+  // The convergence region this structure holds.
+  ConvergenceRegion *TopLevelRegion;
+
+public:
+  ConvergenceRegionInfo() : TopLevelRegion(nullptr) {}
+
+  // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is
+  // passed to this object.
+  ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion)
+      : TopLevelRegion(TopLevelRegion) {}
+
+  ~ConvergenceRegionInfo() { releaseMemory(); }
+
+  // Takes over the hierarchy held by |LHS|, which is left empty.
+  ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS)
+      : TopLevelRegion(LHS.TopLevelRegion) {
+    LHS.TopLevelRegion = nullptr;
+  }
+
+  // Frees the held hierarchy and adopts |LHS|'s. Self-assignment is a no-op.
+  ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) {
+    if (this != &LHS) {
+      if (TopLevelRegion != LHS.TopLevelRegion)
+        releaseMemory();
+      TopLevelRegion = LHS.TopLevelRegion;
+      LHS.TopLevelRegion = nullptr;
+    }
+    return *this;
+  }
+
+  // Deletes the held hierarchy, if any, and leaves this object empty.
+  void releaseMemory() {
+    if (TopLevelRegion == nullptr)
+      return;
+    TopLevelRegion->releaseMemory();
+    delete TopLevelRegion;
+    TopLevelRegion = nullptr;
+  }
+
+  // Returns the root region, or nullptr when no hierarchy is held.
+  const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
+};
+
+// Wraps getConvergenceRegions() for use with the legacy pass manager.
+class ConvergenceRegionAnalysisWrapperPass : public FunctionPass {
+  ConvergenceRegionInfo CRI;
+
+public:
+  static char ID;
+
+  ConvergenceRegionAnalysisWrapperPass();
+  // Requires loop info and the dominator tree; preserves every analysis.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+  }
+
+  bool runOnFunction(Function &F) override;
+  // Gives access to the regions computed by runOnFunction().
+  ConvergenceRegionInfo &getRegionInfo() { return CRI; }
+  const ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
+};
+
+// Wraps getConvergenceRegions() for use with the new pass manager.
+class ConvergenceRegionAnalysis
+    : public AnalysisInfoMixin<ConvergenceRegionAnalysis> {
+  friend AnalysisInfoMixin<ConvergenceRegionAnalysis>;
+  static AnalysisKey Key;
+
+public:
+  using Result = ConvergenceRegionInfo;
+  // Computes the convergence region hierarchy of |F|.
+  Result run(Function &F, FunctionAnalysisManager &AM);
+};
+
+ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
+                                            LoopInfo &LI);
+
+} // namespace SPIRV
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 7d17c307db13a0..76710c44767665 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -44,6 +44,7 @@ add_llvm_target(SPIRVCodeGen
   Core
   Demangle
   GlobalISel
+  SPIRVAnalysis
   MC
   SPIRVDesc
   SPIRVInfo
@@ -59,3 +60,4 @@ add_llvm_target(SPIRVCodeGen
 
 add_subdirectory(MCTargetDesc)
 add_subdirectory(TargetInfo)
+add_subdirectory(Analysis)
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
new file mode 100644
index 00000000000000..8f9b81c759a00e
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+include_directories(
+  ${LLVM_MAIN_SRC_DIR}/lib/Target/SPIRV
+  ${LLVM_BINARY_DIR}/lib/Target/SPIRV
+  )
+
+set(LLVM_LINK_COMPONENTS
+  AsmParser
+  Core
+  SPIRVCodeGen
+  SPIRVAnalysis
+  Support
+  )
+
+add_llvm_target_unittest(SPIRVTests
+  ConvergenceRegionAnalysisTests.cpp
+  )
+
diff --git a/llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp b/llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp
new file mode 100644
index 00000000000000..2a01b00dad42a8
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp
@@ -0,0 +1,1015 @@
+//===- llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/ConvergenceRegionAnalysis.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <queue>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::SPIRV;
+// Comparison helper: `V == IsA<T>()` holds when |V| is an instance of T.
+template <typename T> struct IsA {
+  friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
+};
+// Fixture running ConvergenceRegionAnalysis on the @main of a parsed module.
+class ConvergenceRegionAnalysisTest : public testing::Test {
+protected:
+  void SetUp() override {
+    // Required for tests.
+    FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+    // Required for ConvergenceRegionAnalysis.
+    FAM.registerPass([&] { return DominatorTreeAnalysis(); });
+    FAM.registerPass([&] { return LoopAnalysis(); });
+
+    FAM.registerPass([&] { return ConvergenceRegionAnalysis(); });
+  }
+  // Releases the parsed module.
+  void TearDown() override { M.reset(); }
+  // Parses |Assembly| and returns the analysis result for its @main function.
+  ConvergenceRegionAnalysis::Result &runAnalysis(StringRef Assembly) {
+    assert(M == nullptr &&
+           "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
+
+    SMDiagnostic Error;
+    M = parseAssemblyString(Assembly, Error, Context);
+    assert(M && "Bad assembly. Bad test?");
+    auto *F = getFunction();
+
+    ModulePassManager MPM;
+    MPM.run(*M, MAM);
+    return FAM.getResult<ConvergenceRegionAnalysis>(*F);
+  }
+  // Returns the analysis result for @main; requires a prior runAnalysis().
+  ConvergenceRegionAnalysis::Result &getAnalysis() {
+    assert(M != nullptr && "Has runAnalysis been called before?");
+    return FAM.getResult<ConvergenceRegionAnalysis>(*getFunction());
+  }
+  // Returns the function named "main" of the parsed module.
+  Function *getFunction() const {
+    assert(M != nullptr && "Has runAnalysis been called before?");
+    return M->getFunction("main");
+  }
+  // Returns the block of @main named |Name|; fails the test if it is missing.
+  const BasicBlock *getBlock(StringRef Name) {
+    assert(M != nullptr && "Has runAnalysis been called before?");
+
+    auto *F = getFunction();
+    for (BasicBlock &BB : *F) {
+      if (BB.getName() == Name)
+        return &BB;
+    }
+
+    ADD_FAILURE() << "Error: Could not locate requested block. Bad test?";
+    return nullptr;
+  }
+  // Breadth-first search of the region whose entry block is named |Name|.
+  const ConvergenceRegion *getRegionWithEntry(StringRef Name) {
+    assert(M != nullptr && "Has runAnalysis been called before?");
+
+    std::queue<const ConvergenceRegion *> ToProcess;
+    ToProcess.push(getAnalysis().getTopLevelRegion());
+
+    while (ToProcess.size() != 0) {
+      auto *R = ToProcess.front();
+      ToProcess.pop();
+      for (auto *Child : R->Children)
+        ToProcess.push(Child);
+
+      if (R->Entry->getName() == Name)
+        return R;
+    }
+
+    ADD_FAILURE() << "Error: Could not locate requested region. Bad test?";
+    return nullptr;
+  }
+  // Expects |R| to contain all |InRegion| blocks and no |NotInRegion| block.
+  void checkRegionBlocks(const ConvergenceRegion *R,
+                         std::initializer_list<const char *> InRegion,
+                         std::initializer_list<const char *> NotInRegion) {
+    for (const char *Name : InRegion) {
+      EXPECT_TRUE(R->contains(getBlock(Name)))
+          << "error: " << Name << " not in region " << R->Entry->getName();
+    }
+
+    for (const char *Name : NotInRegion) {
+      EXPECT_FALSE(R->contains(getBlock(Name)))
+          << "error: " << Name << " in region " << R->Entry->getName();
+    }
+  }
+
+protected:
+  LLVMContext Context;
+  FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
+  std::unique_ptr<Module> M;
+};
+// Matches a range of basic block pointers holding a block named |label|.
+MATCHER_P(ContainsBasicBlock, label, "") {
+  for (const auto *bb : arg)
+    if (bb->getName() == label)
+      return true;
+  return false;
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, DefaultRegion) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ret void
+    }
+  )";
+  // With no loop and no intrinsic, only a token-less top-level region exists.
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+
+  EXPECT_EQ(CR->Parent, nullptr);
+  EXPECT_EQ(CR->ConvergenceToken, std::nullopt);
+  EXPECT_EQ(CR->Children.size(), 0u);
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, DefaultRegionWithToken) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+  )";
+
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+  // The token must exist before it is dereferenced: abort the test otherwise.
+  ASSERT_TRUE(CR->ConvergenceToken.has_value());
+  EXPECT_EQ(CR->Parent, nullptr);
+  EXPECT_EQ(CR->Children.size(), 0u);
+  EXPECT_EQ(CR->ConvergenceToken.value()->getIntrinsicID(),
+            Intrinsic::experimental_convergence_entry);
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopOneRegion) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1
+
+    l1:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_end
+
+    l1_body:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+  )";
+
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+  // The token must exist before any of the value() accesses below.
+  ASSERT_TRUE(CR->ConvergenceToken.has_value());
+  EXPECT_EQ(CR->Parent, nullptr);
+  EXPECT_EQ(CR->ConvergenceToken.value()->getName(), "t1");
+  EXPECT_EQ(CR->ConvergenceToken.value()->getIntrinsicID(),
+            Intrinsic::experimental_convergence_entry);
+  EXPECT_EQ(CR->Children.size(), 1u);
+}
+
+TEST_F(ConvergenceRegionAnalysisTest,
+       SingleLoopLoopRegionParentsIsTopLevelRegion) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1
+
+    l1:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_end
+
+    l1_body:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+  )";
+  // Children[0] is only read once the child count has been asserted.
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+  ASSERT_EQ(CR->Children.size(), 1u);
+  EXPECT_EQ(CR->Parent, nullptr);
+  EXPECT_EQ(CR->ConvergenceToken.value()->getName(), "t1");
+  EXPECT_EQ(CR->Children[0]->Parent, CR);
+  EXPECT_EQ(CR->Children[0]->ConvergenceToken.value()->getName(), "tl1");
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopExits) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1
+
+    l1:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_end
+
+    l1_body:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+  )";
+  // The loop region must exist before its exits can be inspected.
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+  ASSERT_EQ(CR->Children.size(), 1u);
+  const auto *L = CR->Children[0];
+  EXPECT_EQ(L->Exits.size(), 1u);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1"));
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakExits) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %end.loopexit
+
+    l1_body:
+      %2 = icmp ne i32 0, 0
+      br i1 %2, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    end.loopexit:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+  // The loop region must exist before its exits can be inspected.
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+  ASSERT_EQ(CR->Children.size(), 1u);
+  const auto *L = CR->Children[0];
+  EXPECT_EQ(L->Exits.size(), 2u);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_header"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_condition_true"));
+
+  EXPECT_TRUE(CR->contains(getBlock("l1_header")));
+  EXPECT_TRUE(CR->contains(getBlock("l1_condition_true")));
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakRegionBlocks) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %end.loopexit
+
+    l1_body:
+      %2 = icmp ne i32 0, 0
+      br i1 %2, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    end.loopexit:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+  // The loop region must exist before its blocks can be inspected.
+  const auto *CR = runAnalysis(Assembly).getTopLevelRegion();
+  ASSERT_EQ(CR->Children.size(), 1u);
+  const auto *L = CR->Children[0];
+  EXPECT_TRUE(CR->contains(getBlock("l1_header")));
+  EXPECT_TRUE(L->contains(getBlock("l1_header")));
+
+  EXPECT_TRUE(CR->contains(getBlock("l1_body")));
+  EXPECT_TRUE(L->contains(getBlock("l1_body")));
+
+  EXPECT_TRUE(CR->contains(getBlock("l1_condition_true")));
+  EXPECT_TRUE(L->contains(getBlock("l1_condition_true")));
+
+  EXPECT_TRUE(CR->contains(getBlock("l1_condition_false")));
+  EXPECT_TRUE(L->contains(getBlock("l1_condition_false")));
+
+  EXPECT_TRUE(CR->contains(getBlock("l1_continue")));
+  EXPECT_TRUE(L->contains(getBlock("l1_continue")));
+
+  EXPECT_TRUE(CR->contains(getBlock("end.loopexit")));
+  EXPECT_FALSE(L->contains(getBlock("end.loopexit")));
+
+  EXPECT_TRUE(CR->contains(getBlock("end")));
+  EXPECT_FALSE(L->contains(getBlock("end")));
+}
+
+// Exact same test as before, except the 'if() break' condition in the loop is
+// not marked with any convergence intrinsic. In such case, it is valid to
+// consider it outside of the loop.
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakNoConvergenceControl) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %end.loopexit
+
+    l1_body:
+      %2 = icmp ne i32 0, 0
+      br i1 %2, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      br label %end
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    end.loopexit:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+  )";
+
+  runAnalysis(Assembly);
+  const auto *L = getRegionWithEntry("l1_header");
+
+  EXPECT_EQ(L->Entry->getName(), "l1_header");
+  EXPECT_EQ(L->Exits.size(), 2ul);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_header"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_body"));
+
+  EXPECT_TRUE(L->contains(getBlock("l1_header")));
+  EXPECT_TRUE(L->contains(getBlock("l1_body")));
+  EXPECT_FALSE(L->contains(getBlock("l1_condition_true")));
+  EXPECT_TRUE(L->contains(getBlock("l1_condition_false")));
+  EXPECT_TRUE(L->contains(getBlock("l1_continue")));
+  EXPECT_FALSE(L->contains(getBlock("end.loopexit")));
+  EXPECT_FALSE(L->contains(getBlock("end")));
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, TwoLoopsWithControl) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_exit
+
+    l1_body:
+      br i1 %1, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      br label %mid
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    l1_exit:
+      br label %mid
+
+    mid:
+      br label %l2_header
+
+    l2_header:
+      %tl2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l2_body, label %l2_exit
+
+    l2_body:
+      br i1 %1, label %l2_condition_true, label %l2_condition_false
+
+    l2_condition_true:
+      br label %end
+
+    l2_condition_false:
+      br label %l2_continue
+
+    l2_continue:
+      br label %l2_header
+
+    l2_exit:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+  )";
+
+  runAnalysis(Assembly);
+
+  {
+    const auto *L = getRegionWithEntry("l1_header");
+
+    EXPECT_EQ(L->Entry->getName(), "l1_header");
+    EXPECT_EQ(L->Exits.size(), 2ul);
+    EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_header"));
+    EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_body"));
+
+    checkRegionBlocks(
+        L, {"l1_header", "l1_body", "l1_condition_false", "l1_continue"},
+        {"", "l2_header", "l2_body", "l2_condition_true", "l2_condition_false",
+         "l2_continue", "l2_exit", "l1_condition_true", "l1_exit", "end"});
+  }
+  {
+    const auto *L = getRegionWithEntry("l2_header");
+
+    EXPECT_EQ(L->Entry->getName(), "l2_header");
+    EXPECT_EQ(L->Exits.size(), 2ul);
+    EXPECT_THAT(L->Exits, ContainsBasicBlock("l2_header"));
+    EXPECT_THAT(L->Exits, ContainsBasicBlock("l2_body"));
+
+    checkRegionBlocks(
+        L, {"l2_header", "l2_body", "l2_condition_false", "l2_continue"},
+        {"", "l1_header", "l1_body", "l1_condition_true", "l1_condition_false",
+         "l1_continue", "l1_exit", "l2_condition_true", "l2_exit", "end"});
+  }
+}
+
+// Both branches in the loop condition break. This means the loop continue
+// targets are unreachable, meaning no reachable back-edge. This should
+// transform the loop condition into a simple condition, meaning we have a
+// single convergence region.
+TEST_F(ConvergenceRegionAnalysisTest, LoopBothBranchExits) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_exit
+
+    l1_body:
+      br i1 %1, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      %call_true = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    l1_condition_false:
+      %call_false = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    l1_continue:
+      br label %l1_header
+
+    l1_exit:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+
+  ;
+  const auto *R = runAnalysis(Assembly).getTopLevelRegion();
+
+  ASSERT_EQ(R->Children.size(), 0ul);
+  EXPECT_EQ(R->Exits.size(), 1ul);
+  EXPECT_THAT(R->Exits, ContainsBasicBlock("end"));
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, InnerLoopBreaks) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_exit
+
+    l1_body:
+      br label %l2_header
+
+    l2_header:
+      %tl2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %tl1) ]
+      br i1 %1, label %l2_body, label %l2_exit
+
+    l2_body:
+      br i1 %1, label %l2_condition_true, label %l2_condition_false
+
+    l2_condition_true:
+      %call_true = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    l2_condition_false:
+      br label %l2_continue
+
+    l2_continue:
+      br label %l2_header
+
+    l2_exit:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    l1_exit:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+
+  const auto *R = runAnalysis(Assembly).getTopLevelRegion();
+  const auto *L1 = getRegionWithEntry("l1_header");
+  const auto *L2 = getRegionWithEntry("l2_header");
+
+  EXPECT_EQ(R->Children.size(), 1ul);
+  EXPECT_EQ(L1->Children.size(), 1ul);
+  EXPECT_EQ(L1->Parent, R);
+  EXPECT_EQ(L2->Parent, L1);
+
+  EXPECT_EQ(R->Entry->getName(), "");
+  EXPECT_EQ(R->Exits.size(), 1ul);
+  EXPECT_THAT(R->Exits, ContainsBasicBlock("end"));
+
+  EXPECT_EQ(L1->Entry->getName(), "l1_header");
+  EXPECT_EQ(L1->Exits.size(), 2ul);
+  EXPECT_THAT(L1->Exits, ContainsBasicBlock("l1_header"));
+  EXPECT_THAT(L1->Exits, ContainsBasicBlock("l2_condition_true"));
+
+  checkRegionBlocks(L1,
+                    {"l1_header", "l1_body", "l2_header", "l2_body",
+                     "l2_condition_false", "l2_condition_true", "l2_continue",
+                     "l2_exit", "l1_continue"},
+                    {"", "l1_exit", "end"});
+
+  EXPECT_EQ(L2->Entry->getName(), "l2_header");
+  EXPECT_EQ(L2->Exits.size(), 2ul);
+  EXPECT_THAT(L2->Exits, ContainsBasicBlock("l2_header"));
+  EXPECT_THAT(L2->Exits, ContainsBasicBlock("l2_body"));
+  checkRegionBlocks(
+      L2, {"l2_header", "l2_body", "l2_condition_false", "l2_continue"},
+      {"", "l1_header", "l1_body", "l2_exit", "l1_continue",
+       "l2_condition_true", "l1_exit", "end"});
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopMultipleExits) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %cond = icmp ne i32 0, 0
+      br label %l1
+
+    l1:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %cond, label %l1_body, label %l1_exit
+
+    l1_body:
+      switch i32 0, label %sw.default.exit [
+        i32 0, label %sw.bb
+        i32 1, label %sw.bb1
+        i32 2, label %sw.bb2
+      ]
+
+    sw.default.exit:
+      br label %sw.default
+
+    sw.default:
+      br label %l1_end
+
+    sw.bb:
+      br label %l1_end
+
+    sw.bb1:
+      br label %l1_continue
+
+    sw.bb2:
+      br label %sw.default
+
+    l1_continue:
+      br label %l1
+
+    l1_exit:
+      br label %l1_end
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+  )";
+
+  runAnalysis(Assembly).getTopLevelRegion();
+  const auto *L = getRegionWithEntry("l1");
+  ASSERT_NE(L, nullptr);
+
+  EXPECT_EQ(L->Entry, getBlock("l1"));
+  EXPECT_EQ(L->Exits.size(), 2ul);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_body"));
+
+  checkRegionBlocks(L, {"l1", "l1_body", "l1_continue", "sw.bb1"},
+                    {"", "sw.default.exit", "sw.default", "l1_end", "end",
+                     "sw.bb", "sw.bb2", "l1_exit"});
+}
+
+TEST_F(ConvergenceRegionAnalysisTest,
+       SingleLoopMultipleExitsWithPartialConvergence) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %cond = icmp ne i32 0, 0
+      br label %l1
+
+    l1:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %cond, label %l1_body, label %l1_exit
+
+    l1_body:
+      switch i32 0, label %sw.default.exit [
+        i32 0, label %sw.bb
+        i32 1, label %sw.bb1
+        i32 2, label %sw.bb2
+      ]
+
+    sw.default.exit:
+      br label %sw.default
+
+    sw.default:
+      %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %l1_end
+
+    sw.bb:
+      br label %l1_end
+
+    sw.bb1:
+      br label %l1_continue
+
+    sw.bb2:
+      br label %sw.default
+
+    l1_continue:
+      br label %l1
+
+    l1_exit:
+      br label %l1_end
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+
+  runAnalysis(Assembly).getTopLevelRegion();
+  const auto *L = getRegionWithEntry("l1");
+  ASSERT_NE(L, nullptr);
+
+  EXPECT_EQ(L->Entry, getBlock("l1"));
+  EXPECT_EQ(L->Exits.size(), 3ul);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_body"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("sw.default"));
+
+  checkRegionBlocks(L,
+                    {"l1", "l1_body", "l1_continue", "sw.bb1",
+                     "sw.default.exit", "sw.bb2", "sw.default"},
+                    {"", "l1_end", "end", "sw.bb", "l1_exit"});
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithDeepConvergenceBranch) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_end
+
+    l1_body:
+      %2 = icmp ne i32 0, 0
+      br i1 %2, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      br label %a
+
+    a:
+      br label %b
+
+    b:
+      br label %c
+
+    c:
+      %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+
+  runAnalysis(Assembly).getTopLevelRegion();
+  const auto *L = getRegionWithEntry("l1_header");
+  ASSERT_NE(L, nullptr);
+
+  EXPECT_EQ(L->Entry, getBlock("l1_header"));
+  EXPECT_EQ(L->Exits.size(), 2ul);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_header"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("c"));
+
+  checkRegionBlocks(L,
+                    {"l1_header", "l1_body", "l1_continue",
+                     "l1_condition_false", "l1_condition_true", "a", "b", "c"},
+                    {"", "l1_end", "end"});
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithDeepConvergenceLateBranch) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_end
+
+    l1_body:
+      %2 = icmp ne i32 0, 0
+      br i1 %2, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      br label %a
+
+    a:
+      br label %b
+
+    b:
+      br i1 %2, label %c, label %d
+
+    c:
+      %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %end
+
+    d:
+      br label %end
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+
+    ; This intrinsic is not convergent. This is only because the backend doesn't
+    ; support convergent operations yet.
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+
+  runAnalysis(Assembly).getTopLevelRegion();
+  const auto *L = getRegionWithEntry("l1_header");
+  ASSERT_NE(L, nullptr);
+
+  EXPECT_EQ(L->Entry, getBlock("l1_header"));
+  EXPECT_EQ(L->Exits.size(), 3ul);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_header"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("b"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("c"));
+
+  checkRegionBlocks(L,
+                    {"l1_header", "l1_body", "l1_continue",
+                     "l1_condition_false", "l1_condition_true", "a", "b", "c"},
+                    {"", "l1_end", "end", "d"});
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithNoConvergenceIntrinsics) {
+  StringRef Assembly = R"(
+    define void @main() "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %1 = icmp ne i32 0, 0
+      br label %l1_header
+
+    l1_header:
+      br i1 %1, label %l1_body, label %l1_end
+
+    l1_body:
+      %2 = icmp ne i32 0, 0
+      br i1 %2, label %l1_condition_true, label %l1_condition_false
+
+    l1_condition_true:
+      br label %a
+
+    a:
+      br label %end
+
+    l1_condition_false:
+      br label %l1_continue
+
+    l1_continue:
+      br label %l1_header
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+  )";
+
+  runAnalysis(Assembly).getTopLevelRegion();
+  const auto *L = getRegionWithEntry("l1_header");
+  ASSERT_NE(L, nullptr);
+
+  EXPECT_EQ(L->Entry, getBlock("l1_header"));
+  EXPECT_EQ(L->Exits.size(), 2ul);
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_header"));
+  EXPECT_THAT(L->Exits, ContainsBasicBlock("l1_body"));
+
+  checkRegionBlocks(
+      L, {"l1_header", "l1_body", "l1_continue", "l1_condition_false"},
+      {"", "l1_end", "end", "l1_condition_true", "a"});
+}
+
+TEST_F(ConvergenceRegionAnalysisTest, SimpleFunction) {
+  StringRef Assembly = R"(
+    define void @main() "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ret void
+    }
+  )";
+
+  const auto *R = runAnalysis(Assembly).getTopLevelRegion();
+  ASSERT_NE(R, nullptr);
+
+  EXPECT_EQ(R->Entry, getBlock(""));
+  EXPECT_EQ(R->Exits.size(), 1ul);
+  EXPECT_THAT(R->Exits, ContainsBasicBlock(""));
+  EXPECT_TRUE(R->contains(getBlock("")));
+}

>From ce91a9b554e842a92e274836335340682f6aff32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Fri, 19 Jan 2024 14:23:51 +0100
Subject: [PATCH 2/4] fixup! [SPIR-V] add convergence region analysis

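Move the analysis passes out of the llvm::SPIRV namespace and rename
them (and their files) with a SPIRV prefix instead. The region data
structures and getConvergenceRegions() stay in llvm::SPIRV. Also
register the legacy wrapper pass with the PassRegistry.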
---
 llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt |  2 +-
 ...cpp => SPIRVConvergenceRegionAnalysis.cpp} | 36 +++++++++----
 ...sis.h => SPIRVConvergenceRegionAnalysis.h} | 25 +++++----
 llvm/lib/Target/SPIRV/SPIRV.h                 |  1 +
 llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp  |  1 +
 llvm/unittests/Target/SPIRV/CMakeLists.txt    |  2 +-
 ...> SPIRVConvergenceRegionAnalysisTests.cpp} | 54 ++++++++++---------
 7 files changed, 73 insertions(+), 48 deletions(-)
 rename llvm/lib/Target/SPIRV/Analysis/{ConvergenceRegionAnalysis.cpp => SPIRVConvergenceRegionAnalysis.cpp} (88%)
 rename llvm/lib/Target/SPIRV/Analysis/{ConvergenceRegionAnalysis.h => SPIRVConvergenceRegionAnalysis.h} (90%)
 rename llvm/unittests/Target/SPIRV/{ConvergenceRegionAnalysisTests.cpp => SPIRVConvergenceRegionAnalysisTests.cpp} (94%)

diff --git a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
index 374aee3ed1c766..132d8ff838353a 100644
--- a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_llvm_component_library(LLVMSPIRVAnalysis
-  ConvergenceRegionAnalysis.cpp
+  SPIRVConvergenceRegionAnalysis.cpp
 
   LINK_COMPONENTS
   Core
diff --git a/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
similarity index 88%
rename from llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp
rename to llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
index 5102bc2d4228cc..1b2ded5a69ea4f 100644
--- a/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
@@ -12,14 +12,28 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "ConvergenceRegionAnalysis.h"
+#include "SPIRVConvergenceRegionAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/InitializePasses.h"
 #include <optional>
 #include <queue>
 
+#define DEBUG_TYPE "spirv-convergence-region-analysis"
+
 namespace llvm {
+void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
+
+INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
+                      "convergence-region",
+                      "SPIRV convergence regions analysis", true, true);
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
+                    "convergence-region", "SPIRV convergence regions analysis",
+                    true, true);
+
 namespace SPIRV {
 
 namespace {
@@ -281,30 +295,32 @@ ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
   return Analyzer.analyze();
 }
 
-char ConvergenceRegionAnalysisWrapperPass::ID = 0;
+} // namespace SPIRV
+
+char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;
 
-ConvergenceRegionAnalysisWrapperPass::ConvergenceRegionAnalysisWrapperPass()
+SPIRVConvergenceRegionAnalysisWrapperPass::
+    SPIRVConvergenceRegionAnalysisWrapperPass()
     : FunctionPass(ID) {}
 
-bool ConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
+bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
 
-  CRI = getConvergenceRegions(F, DT, LI);
+  CRI = SPIRV::getConvergenceRegions(F, DT, LI);
   // Nothing was modified.
   return false;
 }
 
-ConvergenceRegionAnalysis::Result
-ConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+SPIRVConvergenceRegionAnalysis::Result
+SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
   Result CRI;
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &LI = AM.getResult<LoopAnalysis>(F);
-  CRI = getConvergenceRegions(F, DT, LI);
+  CRI = SPIRV::getConvergenceRegions(F, DT, LI);
   return CRI;
 }
 
-AnalysisKey ConvergenceRegionAnalysis::Key;
+AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
 
-} // namespace SPIRV
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
similarity index 90%
rename from llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h
rename to llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
index c8cd1c4cd9ddf7..f9e30e4effa1d9 100644
--- a/llvm/lib/Target/SPIRV/Analysis/ConvergenceRegionAnalysis.h
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
@@ -1,4 +1,4 @@
-//===- ConvergenceRegionAnalysis.h -----------------------------*- C++ -*--===//
+//===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -132,14 +132,16 @@ class ConvergenceRegionInfo {
   const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
 };
 
+} // namespace SPIRV
+
 // Wrapper around the function above to use it with the legacy pass manager.
-class ConvergenceRegionAnalysisWrapperPass : public FunctionPass {
-  ConvergenceRegionInfo CRI;
+class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass {
+  SPIRV::ConvergenceRegionInfo CRI;
 
 public:
   static char ID;
 
-  ConvergenceRegionAnalysisWrapperPass();
+  SPIRVConvergenceRegionAnalysisWrapperPass();
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
@@ -149,25 +151,26 @@ class ConvergenceRegionAnalysisWrapperPass : public FunctionPass {
 
   bool runOnFunction(Function &F) override;
 
-  ConvergenceRegionInfo &getRegionInfo() { return CRI; }
-  const ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
+  SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; }
+  const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
 };
 
 // Wrapper around the function above to use it with the new pass manager.
-class ConvergenceRegionAnalysis
-    : public AnalysisInfoMixin<ConvergenceRegionAnalysis> {
-  friend AnalysisInfoMixin<ConvergenceRegionAnalysis>;
+class SPIRVConvergenceRegionAnalysis
+    : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> {
+  friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>;
   static AnalysisKey Key;
 
 public:
-  using Result = ConvergenceRegionInfo;
+  using Result = SPIRV::ConvergenceRegionInfo;
 
   Result run(Function &F, FunctionAnalysisManager &AM);
 };
 
+namespace SPIRV {
 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
                                             LoopInfo &LI);
-
 } // namespace SPIRV
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index b947062d79ea8c..9460b0808cae89 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -30,6 +30,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
                                const RegisterBankInfo &RBI);
 
 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
+void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
 void initializeSPIRVPreLegalizerPass(PassRegistry &);
 void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 3485e367dfc0fb..e1b7bdd3140dbe 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -43,6 +43,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
   PassRegistry &PR = *PassRegistry::getPassRegistry();
   initializeGlobalISel(PR);
   initializeSPIRVModuleAnalysisPass(PR);
+  initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PR);
 }
 
 static std::string computeDataLayout(const Triple &TT) {
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index 8f9b81c759a00e..326a74b0cbe50e 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -12,6 +12,6 @@ set(LLVM_LINK_COMPONENTS
   )
 
 add_llvm_target_unittest(SPIRVTests
-  ConvergenceRegionAnalysisTests.cpp
+  SPIRVConvergenceRegionAnalysisTests.cpp
   )
 
diff --git a/llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp
similarity index 94%
rename from llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp
rename to llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp
index 2a01b00dad42a8..6e6e26c0515747 100644
--- a/llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp
+++ b/llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp
@@ -1,4 +1,4 @@
-//===- llvm/unittests/Target/SPIRV/ConvergenceRegionAnalysisTests.cpp -----===//
+//===- SPIRVConvergenceRegionAnalysisTests.cpp ----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "Analysis/ConvergenceRegionAnalysis.h"
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
 #include "llvm/Analysis/DominanceFrontier.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/AsmParser/Parser.h"
@@ -31,7 +31,7 @@ template <typename T> struct IsA {
   friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
 };
 
-class ConvergenceRegionAnalysisTest : public testing::Test {
+class SPIRVConvergenceRegionAnalysisTest : public testing::Test {
 protected:
   void SetUp() override {
     // Required for tests.
@@ -42,12 +42,12 @@ class ConvergenceRegionAnalysisTest : public testing::Test {
     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
     FAM.registerPass([&] { return LoopAnalysis(); });
 
-    FAM.registerPass([&] { return ConvergenceRegionAnalysis(); });
+    FAM.registerPass([&] { return SPIRVConvergenceRegionAnalysis(); });
   }
 
   void TearDown() override { M.reset(); }
 
-  ConvergenceRegionAnalysis::Result &runAnalysis(StringRef Assembly) {
+  SPIRVConvergenceRegionAnalysis::Result &runAnalysis(StringRef Assembly) {
     assert(M == nullptr &&
            "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
 
@@ -58,12 +58,12 @@ class ConvergenceRegionAnalysisTest : public testing::Test {
 
     ModulePassManager MPM;
     MPM.run(*M, MAM);
-    return FAM.getResult<ConvergenceRegionAnalysis>(*F);
+    return FAM.getResult<SPIRVConvergenceRegionAnalysis>(*F);
   }
 
-  ConvergenceRegionAnalysis::Result &getAnalysis() {
+  SPIRVConvergenceRegionAnalysis::Result &getAnalysis() {
     assert(M != nullptr && "Has runAnalysis been called before?");
-    return FAM.getResult<ConvergenceRegionAnalysis>(*getFunction());
+    return FAM.getResult<SPIRVConvergenceRegionAnalysis>(*getFunction());
   }
 
   Function *getFunction() const {
@@ -132,7 +132,7 @@ MATCHER_P(ContainsBasicBlock, label, "") {
   return false;
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, DefaultRegion) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, DefaultRegion) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       ret void
@@ -146,7 +146,7 @@ TEST_F(ConvergenceRegionAnalysisTest, DefaultRegion) {
   EXPECT_EQ(CR->Children.size(), 0u);
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, DefaultRegionWithToken) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, DefaultRegionWithToken) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -165,7 +165,7 @@ TEST_F(ConvergenceRegionAnalysisTest, DefaultRegionWithToken) {
             Intrinsic::experimental_convergence_entry);
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopOneRegion) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, SingleLoopOneRegion) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -204,7 +204,7 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopOneRegion) {
   EXPECT_EQ(CR->Children.size(), 1u);
 }
 
-TEST_F(ConvergenceRegionAnalysisTest,
+TEST_F(SPIRVConvergenceRegionAnalysisTest,
        SingleLoopLoopRegionParentsIsTopLevelRegion) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
@@ -242,7 +242,7 @@ TEST_F(ConvergenceRegionAnalysisTest,
   EXPECT_EQ(CR->Children[0]->ConvergenceToken.value()->getName(), "tl1");
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopExits) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, SingleLoopExits) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -278,7 +278,7 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopExits) {
   EXPECT_THAT(L->Exits, ContainsBasicBlock("l1"));
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakExits) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, SingleLoopWithBreakExits) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -330,7 +330,7 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakExits) {
   EXPECT_TRUE(CR->contains(getBlock("l1_condition_true")));
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakRegionBlocks) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, SingleLoopWithBreakRegionBlocks) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -399,7 +399,8 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakRegionBlocks) {
 // Exact same test as before, except the 'if() break' condition in the loop is
 // not marked with any convergence intrinsic. In such case, it is valid to
 // consider it outside of the loop.
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakNoConvergenceControl) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest,
+       SingleLoopWithBreakNoConvergenceControl) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -452,7 +453,7 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithBreakNoConvergenceControl) {
   EXPECT_FALSE(L->contains(getBlock("end")));
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, TwoLoopsWithControl) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, TwoLoopsWithControl) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -543,7 +544,7 @@ TEST_F(ConvergenceRegionAnalysisTest, TwoLoopsWithControl) {
 // targets are unreachable, meaning no reachable back-edge. This should
 // transform the loop condition into a simple condition, meaning we have a
 // single convergence region.
-TEST_F(ConvergenceRegionAnalysisTest, LoopBothBranchExits) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, LoopBothBranchExits) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -592,7 +593,7 @@ TEST_F(ConvergenceRegionAnalysisTest, LoopBothBranchExits) {
   EXPECT_THAT(R->Exits, ContainsBasicBlock("end"));
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, InnerLoopBreaks) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, InnerLoopBreaks) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -679,7 +680,7 @@ TEST_F(ConvergenceRegionAnalysisTest, InnerLoopBreaks) {
        "l2_condition_true", "l1_exit", "end"});
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopMultipleExits) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, SingleLoopMultipleExits) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -744,7 +745,7 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopMultipleExits) {
                      "sw.bb", "sw.bb2", "l1_exit"});
 }
 
-TEST_F(ConvergenceRegionAnalysisTest,
+TEST_F(SPIRVConvergenceRegionAnalysisTest,
        SingleLoopMultipleExitsWithPartialConvergence) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
@@ -817,7 +818,8 @@ TEST_F(ConvergenceRegionAnalysisTest,
                     {"", "l1_end", "end", "sw.bb", "l1_exit"});
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithDeepConvergenceBranch) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest,
+       SingleLoopWithDeepConvergenceBranch) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -882,7 +884,8 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithDeepConvergenceBranch) {
                     {"", "l1_end", "end"});
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithDeepConvergenceLateBranch) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest,
+       SingleLoopWithDeepConvergenceLateBranch) {
   StringRef Assembly = R"(
     define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %t1 = call token @llvm.experimental.convergence.entry()
@@ -951,7 +954,8 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithDeepConvergenceLateBranch) {
                     {"", "l1_end", "end", "d"});
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithNoConvergenceIntrinsics) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest,
+       SingleLoopWithNoConvergenceIntrinsics) {
   StringRef Assembly = R"(
     define void @main() "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       %1 = icmp ne i32 0, 0
@@ -998,7 +1002,7 @@ TEST_F(ConvergenceRegionAnalysisTest, SingleLoopWithNoConvergenceIntrinsics) {
       {"", "l1_end", "end", "l1_condition_true", "a"});
 }
 
-TEST_F(ConvergenceRegionAnalysisTest, SimpleFunction) {
+TEST_F(SPIRVConvergenceRegionAnalysisTest, SimpleFunction) {
   StringRef Assembly = R"(
     define void @main() "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
       ret void

>From af06dafb5e4fb6af9c19ccc405f39359a5a53ad8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 22 Jan 2024 15:19:30 +0100
Subject: [PATCH 3/4] fixup! [SPIR-V] add convergence region analysis

PR feedback
---
 .../SPIRVConvergenceRegionAnalysis.cpp        | 38 +++++++++----------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
index 1b2ded5a69ea4f..e93ee6433b2f80 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
@@ -17,25 +17,30 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
 #include <optional>
 #include <queue>
 
 #define DEBUG_TYPE "spirv-convergence-region-analysis"
 
+using namespace llvm;
+
 namespace llvm {
 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
+} // namespace llvm
 
 INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
                       "convergence-region",
-                      "SPIRV convergence regions analysis", true, true);
+                      "SPIRV convergence regions analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
                     "convergence-region", "SPIRV convergence regions analysis",
-                    true, true);
+                    true, true)
 
+namespace llvm {
 namespace SPIRV {
-
 namespace {
 
 template <typename BasicBlockType, typename IntrinsicInstType>
@@ -52,24 +57,19 @@ getConvergenceTokenInternal(BasicBlockType *BB) {
 
   for (auto &I : *BB) {
     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
-      if (II->getIntrinsicID() != Intrinsic::experimental_convergence_entry &&
-          II->getIntrinsicID() != Intrinsic::experimental_convergence_loop &&
-          II->getIntrinsicID() != Intrinsic::experimental_convergence_anchor) {
-        continue;
-      }
-
-      if (II->getIntrinsicID() == Intrinsic::experimental_convergence_entry ||
-          II->getIntrinsicID() == Intrinsic::experimental_convergence_loop) {
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::experimental_convergence_entry:
+      case Intrinsic::experimental_convergence_loop:
         return II;
+      case Intrinsic::experimental_convergence_anchor: {
+        auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);
+        assert(Bundle->Inputs.size() == 1 &&
+               Bundle->Inputs[0]->getType()->isTokenTy());
+        auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());
+        assert(TII != nullptr);
+        return TII;
+      }
       }
-
-      auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);
-      assert(Bundle->Inputs.size() == 1 &&
-             Bundle->Inputs[0]->getType()->isTokenTy());
-      auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());
-      ;
-      assert(TII != nullptr);
-      return TII;
     }
 
     if (auto *CI = dyn_cast<CallInst>(&I)) {

>From dbcf4fcc8087462bb297dce32d3087cd358b8cbc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 22 Jan 2024 17:33:02 +0100
Subject: [PATCH 4/4] fixup! [SPIR-V] add convergence region analysis

fix loop hierarchy
---
 .../SPIRVConvergenceRegionAnalysis.cpp        | 54 +++++++++----
 .../SPIRVConvergenceRegionAnalysisTests.cpp   | 80 +++++++++++++++++++
 2 files changed, 119 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
index e93ee6433b2f80..7f5f7d0b1e4dc5 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
@@ -83,6 +83,33 @@ getConvergenceTokenInternal(BasicBlockType *BB) {
   return std::nullopt;
 }
 
+// Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
+// region containing |Entry|. |Start| itself is not checked: if no descendant
+// contains |Entry|, |Start| is returned (|nullptr| only if |Start| is null).
+ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
+                                    BasicBlock *Entry) {
+  ConvergenceRegion *Candidate = nullptr;
+  ConvergenceRegion *NextCandidate = Start;
+
+  while (Candidate != NextCandidate && NextCandidate != nullptr) {
+    Candidate = NextCandidate;
+    NextCandidate = nullptr;
+
+    // End of the search, we can return.
+    if (Candidate->Children.size() == 0)
+      return Candidate;
+
+    for (auto *Child : Candidate->Children) {
+      if (Child->Blocks.count(Entry) != 0) {
+        NextCandidate = Child;
+        break;
+      }
+    }
+  }
+
+  return Candidate;
+}
+
 } // anonymous namespace
 
 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
@@ -193,7 +220,7 @@ class ConvergenceRegionAnalyzer {
   }
 
   std::unordered_set<BasicBlock *>
-  findPathsToMatch(BasicBlock *From,
+  findPathsToMatch(LoopInfo &LI, BasicBlock *From,
                    std::function<bool(const BasicBlock *)> isMatch) const {
     std::unordered_set<BasicBlock *> Output;
 
@@ -206,12 +233,18 @@ class ConvergenceRegionAnalyzer {
       if (isBackEdge(From, To))
         continue;
 
-      auto ChildSet = findPathsToMatch(To, isMatch);
+      auto ChildSet = findPathsToMatch(LI, To, isMatch);
       if (ChildSet.size() == 0)
         continue;
 
       Output.insert(ChildSet.begin(), ChildSet.end());
       Output.insert(From);
+      if (LI.isLoopHeader(From)) {
+        auto *L = LI.getLoopFor(From);
+        for (auto *BB : L->getBlocks()) {
+          Output.insert(BB);
+        }
+      }
     }
 
     return Output;
@@ -236,18 +269,13 @@ class ConvergenceRegionAnalyzer {
 public:
   ConvergenceRegionInfo analyze() {
     ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
-
-    std::unordered_map<Loop *, ConvergenceRegion *> LoopToRegion;
     std::queue<Loop *> ToProcess;
-    for (auto *L : LI)
+    for (auto *L : LI.getLoopsInPreorder())
       ToProcess.push(L);
 
     while (ToProcess.size() != 0) {
       auto *L = ToProcess.front();
       ToProcess.pop();
-      for (auto *Child : *L)
-        ToProcess.push(Child);
-
       assert(L->isLoopSimplifyForm());
 
       auto CT = getConvergenceToken(L->getHeader());
@@ -257,7 +285,7 @@ class ConvergenceRegionAnalyzer {
       L->getExitingBlocks(LoopExits);
       if (CT.has_value()) {
         for (auto *Exit : LoopExits) {
-          auto N = findPathsToMatch(Exit, [&CT](const BasicBlock *block) {
+          auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
             auto Token = getConvergenceToken(block);
             if (Token == std::nullopt)
               return false;
@@ -271,13 +299,9 @@ class ConvergenceRegionAnalyzer {
       ConvergenceRegion *Region = new ConvergenceRegion(
           DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
           std::move(RegionExits));
-
-      auto It = LoopToRegion.find(L->getParentLoop());
-      assert(It != LoopToRegion.end() || L->getParentLoop() == nullptr);
-      Region->Parent = It != LoopToRegion.end() ? It->second : TopLevelRegion;
+      Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
+      assert(Region->Parent != nullptr && "This is impossible.");
       Region->Parent->Children.push_back(Region);
-
-      LoopToRegion.emplace(L, Region);
     }
 
     return ConvergenceRegionInfo(TopLevelRegion);
diff --git a/llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp
index 6e6e26c0515747..e04fc85df4f93b 100644
--- a/llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp
+++ b/llvm/unittests/Target/SPIRV/SPIRVConvergenceRegionAnalysisTests.cpp
@@ -1017,3 +1017,83 @@ TEST_F(SPIRVConvergenceRegionAnalysisTest, SimpleFunction) {
   EXPECT_THAT(R->Exits, ContainsBasicBlock(""));
   EXPECT_TRUE(R->contains(getBlock("")));
 }
+
+TEST_F(SPIRVConvergenceRegionAnalysisTest, NestedLoopInBreak) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %t1 = call token @llvm.experimental.convergence.entry()
+      %1 = icmp ne i32 0, 0
+      br label %l1
+
+    l1:
+      %tl1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %t1) ]
+      br i1 %1, label %l1_body, label %l1_to_end
+
+    l1_body:
+      br i1 %1, label %cond_inner, label %l1_continue
+
+    cond_inner:
+      br label %l2
+
+    l2:
+      %tl2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %tl1) ]
+      br i1 %1, label %l2_body, label %l2_end
+
+    l2_body:
+      %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl2) ]
+      br label %l2_continue
+
+    l2_continue:
+      br label %l2
+
+    l2_end:
+      br label %l2_exit
+
+    l2_exit:
+      %call2 = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+      br label %l1_end
+
+    l1_continue:
+      br label %l1
+
+    l1_to_end:
+      br label %l1_end
+
+    l1_end:
+      br label %end
+
+    end:
+      ret void
+    }
+
+    declare token @llvm.experimental.convergence.entry()
+    declare token @llvm.experimental.convergence.control()
+    declare token @llvm.experimental.convergence.loop()
+    declare spir_func i32 @_Z3absi(i32) convergent
+  )";
+
+  const auto *R = runAnalysis(Assembly).getTopLevelRegion();
+  ASSERT_NE(R, nullptr);
+
+  EXPECT_EQ(R->Children.size(), 1ul);
+
+  const auto *L1 = R->Children[0];
+  EXPECT_EQ(L1->Children.size(), 1ul);
+  EXPECT_EQ(L1->Entry->getName(), "l1");
+  EXPECT_EQ(L1->Exits.size(), 2ul);
+  EXPECT_THAT(L1->Exits, ContainsBasicBlock("l1"));
+  EXPECT_THAT(L1->Exits, ContainsBasicBlock("l2_exit"));
+  checkRegionBlocks(L1,
+                    {"l1", "l1_body", "l1_continue", "cond_inner", "l2",
+                     "l2_body", "l2_end", "l2_continue", "l2_exit"},
+                    {"", "l1_to_end", "l1_end", "end"});
+
+  const auto *L2 = L1->Children[0];
+  EXPECT_EQ(L2->Children.size(), 0ul);
+  EXPECT_EQ(L2->Entry->getName(), "l2");
+  EXPECT_EQ(L2->Exits.size(), 1ul);
+  EXPECT_THAT(L2->Exits, ContainsBasicBlock("l2"));
+  checkRegionBlocks(L2, {"l2", "l2_body", "l2_continue"},
+                    {"", "l1_to_end", "l1_end", "end", "l1", "l1_body",
+                     "l1_continue", "cond_inner", "l2_end", "l2_exit"});
+}



More information about the llvm-commits mailing list