[llvm] 466bd99 - Revert "[LLVM] move verification of convergence control to a class template"

Sameer Sahasrabuddhe via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 1 04:52:29 PDT 2023


Author: Sameer Sahasrabuddhe
Date: 2023-08-01T17:00:39+05:30
New Revision: 466bd9981150906552a1f2308e3c9065bfcb6741

URL: https://github.com/llvm/llvm-project/commit/466bd9981150906552a1f2308e3c9065bfcb6741
DIFF: https://github.com/llvm/llvm-project/commit/466bd9981150906552a1f2308e3c9065bfcb6741.diff

LOG: Revert "[LLVM] move verification of convergence control to a class template"

This reverts commit 93a3706711fd46d4d487640d91b16c2eec747c9e.

The "extern template" declaration of CycleInfo caused problems in a
shared-library build when CycleInfo was removed from Verifier.cpp. Every shared
object (SO) needs an explicit instantiation corresponding to an extern template
declaration.

Added: 
    

Modified: 
    llvm/include/llvm/ADT/GenericSSAContext.h
    llvm/lib/CodeGen/MachineSSAContext.cpp
    llvm/lib/IR/CMakeLists.txt
    llvm/lib/IR/SSAContext.cpp
    llvm/lib/IR/Verifier.cpp

Removed: 
    llvm/include/llvm/ADT/GenericConvergenceVerifier.h
    llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h
    llvm/include/llvm/IR/ConvergenceVerifier.h
    llvm/lib/IR/ConvergenceVerifier.cpp


################################################################################
diff  --git a/llvm/include/llvm/ADT/GenericConvergenceVerifier.h b/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
deleted file mode 100644
index ae5e9e85856390..00000000000000
--- a/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
+++ /dev/null
@@ -1,75 +0,0 @@
-//===- GenericConvergenceVerifier.h ---------------------------*- C++ -*---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-///
-/// A verifer for the static rules of convergence control tokens that works with
-/// both LLVM IR and MIR.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_ADT_GENERICCONVERGENCEVERIFIER_H
-#define LLVM_ADT_GENERICCONVERGENCEVERIFIER_H
-
-#include "llvm/ADT/GenericCycleInfo.h"
-
-namespace llvm {
-
-template <typename ContextT> class GenericConvergenceVerifier {
-public:
-  using BlockT = typename ContextT::BlockT;
-  using FunctionT = typename ContextT::FunctionT;
-  using ValueRefT = typename ContextT::ValueRefT;
-  using InstructionT = typename ContextT::InstructionT;
-  using DominatorTreeT = typename ContextT::DominatorTreeT;
-  using CycleInfoT = GenericCycleInfo<ContextT>;
-  using CycleT = typename CycleInfoT::CycleT;
-
-  void initialize(raw_ostream *OS,
-                  function_ref<void(const Twine &Message)> FailureCB,
-                  const FunctionT &F) {
-    clear();
-    this->OS = OS;
-    this->FailureCB = FailureCB;
-    Context = ContextT(&F);
-  }
-
-  void clear();
-  void visit(const InstructionT &I);
-  void verify(const DominatorTreeT &DT);
-
-  bool sawTokens() const { return ConvergenceKind == ControlledConvergence; }
-
-private:
-  raw_ostream *OS;
-  std::function<void(const Twine &Message)> FailureCB;
-  DominatorTreeT *DT;
-  CycleInfoT CI;
-  ContextT Context;
-
-  /// Whether the current function has convergencectrl operand bundles.
-  enum {
-    ControlledConvergence,
-    UncontrolledConvergence,
-    NoConvergence
-  } ConvergenceKind = NoConvergence;
-
-  // Cache token uses found so far. Note that we track the unique definitions
-  // and not the token values.
-  DenseMap<const InstructionT *, const InstructionT *> Tokens;
-
-  const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
-  bool isControlledConvergent(const InstructionT &I);
-  bool isConvergent(const InstructionT &I) const;
-
-  void reportFailure(const Twine &Message, ArrayRef<Printable> Values);
-};
-
-} // end namespace llvm
-
-#endif // LLVM_ADT_GENERICCONVERGENCEVERIFIER_H

diff  --git a/llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h b/llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h
deleted file mode 100644
index b049f80091a212..00000000000000
--- a/llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h
+++ /dev/null
@@ -1,204 +0,0 @@
-//===- GenericConvergenceVerifierImpl.h -----------------------*- C++ -*---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-///
-/// A verifer for the static rules of convergence control tokens that works with
-/// both LLVM IR and MIR.
-///
-/// This template implementation resides in a separate file so that it does not
-/// get injected into every .cpp file that includes the generic header.
-///
-/// DO NOT INCLUDE THIS FILE WHEN MERELY USING CYCLEINFO.
-///
-/// This file should only be included by files that implement a
-/// specialization of the relevant templates. Currently these are:
-/// - llvm/lib/IR/Verifier.cpp
-/// - llvm/lib/CodeGen/MachineVerifier.cpp
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_ADT_GENERICCONVERGENCEVERIFIERIMPL_H
-#define LLVM_ADT_GENERICCONVERGENCEVERIFIERIMPL_H
-
-#include "llvm/ADT/GenericConvergenceVerifier.h"
-#include "llvm/ADT/GenericCycleInfo.h"
-#include "llvm/ADT/GenericSSAContext.h"
-#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/IR/Intrinsics.h"
-
-using namespace llvm;
-
-#define Check(C, ...)                                                          \
-  do {                                                                         \
-    if (!(C)) {                                                                \
-      reportFailure(__VA_ARGS__);                                              \
-      return;                                                                  \
-    }                                                                          \
-  } while (false)
-
-#define CheckOrNull(C, ...)                                                    \
-  do {                                                                         \
-    if (!(C)) {                                                                \
-      reportFailure(__VA_ARGS__);                                              \
-      return {};                                                               \
-    }                                                                          \
-  } while (false)
-
-static bool isConvergenceControlIntrinsic(unsigned IntrinsicID) {
-  switch (IntrinsicID) {
-  default:
-    return false;
-  case Intrinsic::experimental_convergence_anchor:
-  case Intrinsic::experimental_convergence_entry:
-  case Intrinsic::experimental_convergence_loop:
-    return true;
-  }
-}
-
-namespace llvm {
-template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
-  Tokens.clear();
-  CI.clear();
-  ConvergenceKind = NoConvergence;
-}
-
-template <class ContextT>
-void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
-  if (isControlledConvergent(I)) {
-    Check(isConvergent(I),
-          "Expected convergent attribute on a controlled convergent call.",
-          {Context.print(&I)});
-    Check(ConvergenceKind != UncontrolledConvergence,
-          "Cannot mix controlled and uncontrolled convergence in the same "
-          "function.",
-          {Context.print(&I)});
-    ConvergenceKind = ControlledConvergence;
-  } else if (isConvergent(I)) {
-    Check(ConvergenceKind != ControlledConvergence,
-          "Cannot mix controlled and uncontrolled convergence in the same "
-          "function.",
-          {Context.print(&I)});
-    ConvergenceKind = UncontrolledConvergence;
-  }
-}
-
-template <class ContextT>
-void GenericConvergenceVerifier<ContextT>::reportFailure(
-    const Twine &Message, ArrayRef<Printable> DumpedValues) {
-  FailureCB(Message);
-  if (OS) {
-    for (auto V : DumpedValues)
-      *OS << V << '\n';
-  }
-}
-
-template <class ContextT>
-void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
-  assert(Context.getFunction());
-  const auto &F = *Context.getFunction();
-
-  DenseMap<const BlockT *, SmallVector<const InstructionT *, 8>> LiveTokenMap;
-  DenseMap<const CycleT *, const InstructionT *> CycleHearts;
-
-  // Just like the DominatorTree, compute the CycleInfo locally so that we
-  // can run the verifier outside of a pass manager and we don't rely on
-  // potentially out-dated analysis results.
-  CI.compute(const_cast<FunctionT &>(F));
-
-  auto checkToken = [&](const InstructionT *Token, const InstructionT *User,
-                        SmallVectorImpl<const InstructionT *> &LiveTokens) {
-    Check(llvm::is_contained(LiveTokens, Token),
-          "Convergence region is not well-nested.",
-          {Context.print(Token), Context.print(User)});
-    while (LiveTokens.back() != Token)
-      LiveTokens.pop_back();
-
-    // Check static rules about cycles.
-    auto *BB = User->getParent();
-    auto *BBCycle = CI.getCycle(BB);
-    if (!BBCycle)
-      return;
-
-    auto *DefBB = Token->getParent();
-    if (DefBB == BB || BBCycle->contains(DefBB)) {
-      // degenerate occurrence of a loop intrinsic
-      return;
-    }
-
-    Check(ContextT::getIntrinsicID(*User) ==
-              Intrinsic::experimental_convergence_loop,
-          "Convergence token used by an instruction other than "
-          "llvm.experimental.convergence.loop in a cycle that does "
-          "not contain the token's definition.",
-          {Context.print(User), CI.print(BBCycle)});
-
-    while (true) {
-      auto *Parent = BBCycle->getParentCycle();
-      if (!Parent || Parent->contains(DefBB))
-        break;
-      BBCycle = Parent;
-    };
-
-    Check(BBCycle->isReducible() && BB == BBCycle->getHeader(),
-          "Cycle heart must dominate all blocks in the cycle.",
-          {Context.print(User), Context.printAsOperand(BB), CI.print(BBCycle)});
-    Check(!CycleHearts.count(BBCycle),
-          "Two static convergence token uses in a cycle that does "
-          "not contain either token's definition.",
-          {Context.print(User), Context.print(CycleHearts[BBCycle]),
-           CI.print(BBCycle)});
-    CycleHearts[BBCycle] = User;
-  };
-
-  ReversePostOrderTraversal<const FunctionT *> RPOT(&F);
-  SmallVector<const InstructionT *, 8> LiveTokens;
-  for (auto *BB : RPOT) {
-    LiveTokens.clear();
-    auto LTIt = LiveTokenMap.find(BB);
-    if (LTIt != LiveTokenMap.end()) {
-      LiveTokens = std::move(LTIt->second);
-      LiveTokenMap.erase(LTIt);
-    }
-
-    for (auto &I : *BB) {
-      if (auto *Token = Tokens.lookup(&I))
-        checkToken(Token, &I, LiveTokens);
-      if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
-        LiveTokens.push_back(&I);
-    }
-
-    // Propagate token liveness
-    for (auto *Succ : successors(BB)) {
-      auto *SuccNode = DT.getNode(Succ);
-      auto LTIt = LiveTokenMap.find(Succ);
-      if (LTIt == LiveTokenMap.end()) {
-        // We're the first predecessor: all tokens which dominate the
-        // successor are live for now.
-        LTIt = LiveTokenMap.try_emplace(Succ).first;
-        for (auto LiveToken : LiveTokens) {
-          if (!DT.dominates(DT.getNode(LiveToken->getParent()), SuccNode))
-            break;
-          LTIt->second.push_back(LiveToken);
-        }
-      } else {
-        // Compute the intersection of live tokens.
-        auto It = llvm::partition(
-            LTIt->second, [&LiveTokens](const InstructionT *Token) {
-              return llvm::is_contained(LiveTokens, Token);
-            });
-        LTIt->second.erase(It, LTIt->second.end());
-      }
-    }
-  }
-}
-
-} // end namespace llvm
-
-#endif // LLVM_ADT_GENERICCONVERGENCEVERIFIERIMPL_H

diff  --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b6..45f5dc7d3fcd1d 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -24,10 +24,6 @@ namespace llvm {
 template <typename, bool> class DominatorTreeBase;
 template <typename> class SmallVectorImpl;
 
-namespace Intrinsic {
-typedef unsigned ID;
-}
-
 // Specializations of this template should provide the types used by the
 // template GenericSSAContext below.
 template <typename _FunctionT> struct GenericSSATraits;
@@ -82,8 +78,6 @@ template <typename _FunctionT> class GenericSSAContext {
 
   const FunctionT *getFunction() const { return F; }
 
-  static Intrinsic::ID getIntrinsicID(const InstructionT &I);
-
   static void appendBlockDefs(SmallVectorImpl<ValueRefT> &defs, BlockT &block);
   static void appendBlockDefs(SmallVectorImpl<ConstValueRefT> &defs,
                               const BlockT &block);
@@ -97,7 +91,6 @@ template <typename _FunctionT> class GenericSSAContext {
   const BlockT *getDefBlock(ConstValueRefT value) const;
 
   Printable print(const BlockT *block) const;
-  Printable printAsOperand(const BlockT *BB) const;
   Printable print(const InstructionT *inst) const;
   Printable print(ConstValueRefT value) const;
 };

diff  --git a/llvm/include/llvm/IR/ConvergenceVerifier.h b/llvm/include/llvm/IR/ConvergenceVerifier.h
deleted file mode 100644
index df4c495f6a66fd..00000000000000
--- a/llvm/include/llvm/IR/ConvergenceVerifier.h
+++ /dev/null
@@ -1,28 +0,0 @@
-//===- ConvergenceVerifier.h - Verify convergenctrl -------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-/// \file
-///
-/// This file declares the LLVM IR specialization of the
-/// GenericConvergenceVerifier template.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_IR_CONVERGENCEVERIFIER_H
-#define LLVM_IR_CONVERGENCEVERIFIER_H
-
-#include "llvm/ADT/GenericConvergenceVerifier.h"
-#include "llvm/IR/SSAContext.h"
-
-namespace llvm {
-
-extern template class GenericConvergenceVerifier<SSAContext>;
-using ConvergenceVerifier = GenericConvergenceVerifier<SSAContext>;
-
-} // namespace llvm
-
-#endif // LLVM_IR_CONVERGENCEVERIFIER_H

diff  --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp
index e384187b6e8593..4311255ed76d04 100644
--- a/llvm/lib/CodeGen/MachineSSAContext.cpp
+++ b/llvm/lib/CodeGen/MachineSSAContext.cpp
@@ -13,7 +13,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/MachineSSAContext.h"
-#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstr.h"
@@ -59,13 +58,6 @@ bool MachineSSAContext::isConstantOrUndefValuePhi(const MachineInstr &Phi) {
   return Phi.isConstantValuePHI();
 }
 
-template <>
-Intrinsic::ID MachineSSAContext::getIntrinsicID(const MachineInstr &MI) {
-  if (auto *GI = dyn_cast<GIntrinsic>(&MI))
-    return GI->getIntrinsicID();
-  return Intrinsic::not_intrinsic;
-}
-
 template <>
 Printable MachineSSAContext::print(const MachineBasicBlock *Block) const {
   if (!Block)
@@ -91,8 +83,3 @@ template <> Printable MachineSSAContext::print(Register Value) const {
     }
   });
 }
-
-template <>
-Printable MachineSSAContext::printAsOperand(const MachineBasicBlock *BB) const {
-  return Printable([BB](raw_ostream &Out) { BB->printAsOperand(Out); });
-}

diff  --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index d9656a24d0ed3f..217fe703dd4eef 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -10,7 +10,6 @@ add_llvm_component_library(LLVMCore
   ConstantFold.cpp
   ConstantRange.cpp
   Constants.cpp
-  ConvergenceVerifier.cpp
   Core.cpp
   CycleInfo.cpp
   DIBuilder.cpp

diff  --git a/llvm/lib/IR/ConvergenceVerifier.cpp b/llvm/lib/IR/ConvergenceVerifier.cpp
deleted file mode 100644
index e1c15a3a840837..00000000000000
--- a/llvm/lib/IR/ConvergenceVerifier.cpp
+++ /dev/null
@@ -1,83 +0,0 @@
-//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// NOTE: Including the following header causes a premature instantiation of the
-// template, and the compiler complains about explicit specialization after
-// instantiation. So don't include it.
-//
-// #include "llvm/IR/ConvergenceVerifier.h"
-//===----------------------------------------------------------------------===//
-
-#include "llvm/ADT/GenericConvergenceVerifierImpl.h"
-#include "llvm/IR/Dominators.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/SSAContext.h"
-
-using namespace llvm;
-
-template <>
-const Instruction *
-GenericConvergenceVerifier<SSAContext>::findAndCheckConvergenceTokenUsed(
-    const Instruction &I) {
-  auto *CB = dyn_cast<CallBase>(&I);
-  if (!CB)
-    return nullptr;
-
-  unsigned Count =
-      CB->countOperandBundlesOfType(LLVMContext::OB_convergencectrl);
-  CheckOrNull(Count <= 1,
-              "The 'convergencetrl' bundle can occur at most once on a call",
-              {Context.print(CB)});
-  if (!Count)
-    return nullptr;
-
-  auto Bundle = CB->getOperandBundle(LLVMContext::OB_convergencectrl);
-  CheckOrNull(Bundle->Inputs.size() == 1 &&
-                  Bundle->Inputs[0]->getType()->isTokenTy(),
-              "The 'convergencectrl' bundle requires exactly one token use.",
-              {Context.print(CB)});
-  auto *Token = Bundle->Inputs[0].get();
-  auto *Def = dyn_cast<Instruction>(Token);
-
-  CheckOrNull(
-      Def && isConvergenceControlIntrinsic(SSAContext::getIntrinsicID(*Def)),
-      "Convergence control tokens can only be produced by calls to the "
-      "convergence control intrinsics.",
-      {Context.print(Token), Context.print(&I)});
-
-  if (Def)
-    Tokens[&I] = Def;
-
-  return Def;
-}
-
-template <>
-bool GenericConvergenceVerifier<SSAContext>::isConvergent(
-    const InstructionT &I) const {
-  if (auto *CB = dyn_cast<CallBase>(&I)) {
-    return CB->isConvergent();
-  }
-  return false;
-}
-
-template <>
-bool GenericConvergenceVerifier<SSAContext>::isControlledConvergent(
-    const InstructionT &I) {
-  // First find a token and place it in the map.
-  if (findAndCheckConvergenceTokenUsed(I))
-    return true;
-
-  // The entry and anchor intrinsics do not use a token, so we do a broad check
-  // here. The loop intrinsic will be checked separately for a missing token.
-  if (isConvergenceControlIntrinsic(SSAContext::getIntrinsicID(I)))
-    return true;
-
-  return false;
-}
-
-template class llvm::GenericConvergenceVerifier<SSAContext>;

diff  --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp
index 220abe3083ebd7..3a5c4bf4aa30c0 100644
--- a/llvm/lib/IR/SSAContext.cpp
+++ b/llvm/lib/IR/SSAContext.cpp
@@ -16,10 +16,10 @@
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/ModuleSlotTracker.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/IR/ModuleSlotTracker.h"
 
 using namespace llvm;
 
@@ -69,12 +69,6 @@ bool SSAContext::isConstantOrUndefValuePhi(const Instruction &Instr) {
   return false;
 }
 
-template <> Intrinsic::ID SSAContext::getIntrinsicID(const Instruction &I) {
-  if (auto *CB = dyn_cast<CallBase>(&I))
-    return CB->getIntrinsicID();
-  return Intrinsic::not_intrinsic;
-}
-
 template <> Printable SSAContext::print(const Value *V) const {
   return Printable([V](raw_ostream &Out) { V->print(Out); });
 }
@@ -95,7 +89,3 @@ template <> Printable SSAContext::print(const BasicBlock *BB) const {
     Out << MST.getLocalSlot(BB);
   });
 }
-
-template <> Printable SSAContext::printAsOperand(const BasicBlock *BB) const {
-  return Printable([BB](raw_ostream &Out) { BB->printAsOperand(Out); });
-}

diff  --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 2374aeb8c1d666..21acb4185f7a8e 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -73,7 +73,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
-#include "llvm/IR/ConvergenceVerifier.h"
+#include "llvm/IR/CycleInfo.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -329,6 +329,13 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   /// The current source language.
   dwarf::SourceLanguage CurrentSourceLang = dwarf::DW_LANG_lo_user;
 
+  /// Whether the current function has convergencectrl operand bundles.
+  enum {
+    ControlledConvergence,
+    UncontrolledConvergence,
+    NoConvergence
+  } ConvergenceKind = NoConvergence;
+
   /// Whether source was present on the first DIFile encountered in each CU.
   DenseMap<const DICompileUnit *, bool> HasSourceDebugInfo;
 
@@ -363,7 +370,6 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   SmallVector<const DILocalVariable *, 16> DebugFnArgs;
 
   TBAAVerifier TBAAVerifyHelper;
-  ConvergenceVerifier CV;
 
   SmallVector<IntrinsicInst *, 4> NoAliasScopeDecls;
 
@@ -405,19 +411,12 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
       return false;
     }
 
-    auto FailureCB = [this](const Twine &Message) {
-      this->CheckFailed(Message);
-    };
-    CV.initialize(OS, FailureCB, F);
-
     Broken = false;
     // FIXME: We strip const here because the inst visitor strips const.
     visit(const_cast<Function &>(F));
     verifySiblingFuncletUnwinds();
-
-    if (CV.sawTokens())
-      CV.verify(DT);
-
+    if (ConvergenceKind == ControlledConvergence)
+      verifyConvergenceControl(const_cast<Function &>(F));
     InstsInThisBlock.clear();
     DebugFnArgs.clear();
     LandingPadResultTy = nullptr;
@@ -425,6 +424,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
     SiblingFuncletInfo.clear();
     verifyNoAliasScopeDecl();
     NoAliasScopeDecls.clear();
+    ConvergenceKind = NoConvergence;
 
     return !Broken;
   }
@@ -600,6 +600,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   void verifyStatepoint(const CallBase &Call);
   void verifyFrameRecoverIndices();
   void verifySiblingFuncletUnwinds();
+  void verifyConvergenceControl(Function &F);
 
   void verifyFragmentExpression(const DbgVariableIntrinsic &I);
   template <typename ValueOrMetadata>
@@ -2534,6 +2535,138 @@ void Verifier::verifySiblingFuncletUnwinds() {
   }
 }
 
+static bool isConvergenceControlIntrinsic(const CallBase &Call) {
+  switch (Call.getIntrinsicID()) {
+  case Intrinsic::experimental_convergence_anchor:
+  case Intrinsic::experimental_convergence_entry:
+  case Intrinsic::experimental_convergence_loop:
+    return true;
+  default:
+    return false;
+  }
+}
+
+static bool isControlledConvergent(const CallBase &Call) {
+  if (Call.countOperandBundlesOfType(LLVMContext::OB_convergencectrl))
+    return true;
+  return isConvergenceControlIntrinsic(Call);
+}
+
+void Verifier::verifyConvergenceControl(Function &F) {
+  DenseMap<BasicBlock *, SmallVector<CallBase *, 8>> LiveTokenMap;
+  DenseMap<const Cycle *, const CallBase *> CycleHearts;
+
+  // Just like the DominatorTree, compute the CycleInfo locally so that we
+  // can run the verifier outside of a pass manager and we don't rely on
+  // potentially out-dated analysis results.
+  CycleInfo CI;
+  CI.compute(F);
+
+  auto checkBundle = [&](OperandBundleUse &Bundle, CallBase *CB,
+                         SmallVectorImpl<CallBase *> &LiveTokens) {
+    Check(Bundle.Inputs.size() == 1 && Bundle.Inputs[0]->getType()->isTokenTy(),
+          "The 'convergencectrl' bundle requires exactly one token use.", CB);
+
+    Value *Token = Bundle.Inputs[0].get();
+    auto *Def = dyn_cast<CallBase>(Token);
+    Check(Def && isConvergenceControlIntrinsic(*Def),
+          "Convergence control tokens can only be produced by calls to the "
+          "convergence control intrinsics.",
+          Token, CB);
+
+    Check(llvm::is_contained(LiveTokens, Token),
+          "Convergence region is not well-nested.", Token, CB);
+
+    while (LiveTokens.back() != Token)
+      LiveTokens.pop_back();
+
+    // Check static rules about cycles.
+    auto *BB = CB->getParent();
+    auto *BBCycle = CI.getCycle(BB);
+    if (!BBCycle)
+      return;
+
+    BasicBlock *DefBB = Def->getParent();
+    if (DefBB == BB || BBCycle->contains(DefBB)) {
+      // degenerate occurrence of a loop intrinsic
+      return;
+    }
+
+    auto *II = dyn_cast<IntrinsicInst>(CB);
+    Check(II &&
+              II->getIntrinsicID() == Intrinsic::experimental_convergence_loop,
+          "Convergence token used by an instruction other than "
+          "llvm.experimental.convergence.loop in a cycle that does "
+          "not contain the token's definition.",
+          CB, CI.print(BBCycle));
+
+    while (true) {
+      auto *Parent = BBCycle->getParentCycle();
+      if (!Parent || Parent->contains(DefBB))
+        break;
+      BBCycle = Parent;
+    };
+
+    Check(BBCycle->isReducible() && BB == BBCycle->getHeader(),
+          "Cycle heart must dominate all blocks in the cycle.", CB, BB,
+          CI.print(BBCycle));
+    Check(!CycleHearts.count(BBCycle),
+          "Two static convergence token uses in a cycle that does "
+          "not contain either token's definition.",
+          CB, CycleHearts[BBCycle], CI.print(BBCycle));
+    CycleHearts[BBCycle] = CB;
+  };
+
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  SmallVector<CallBase *, 8> LiveTokens;
+  for (BasicBlock *BB : RPOT) {
+    LiveTokens.clear();
+    auto LTIt = LiveTokenMap.find(BB);
+    if (LTIt != LiveTokenMap.end()) {
+      LiveTokens = std::move(LTIt->second);
+      LiveTokenMap.erase(LTIt);
+    }
+
+    for (Instruction &I : *BB) {
+      CallBase *CB = dyn_cast<CallBase>(&I);
+      if (!CB)
+        continue;
+
+      Check(CB->countOperandBundlesOfType(LLVMContext::OB_convergencectrl) <= 1,
+            "The 'convergencetrl' bundle can occur at most once on a call", CB);
+
+      auto Bundle = CB->getOperandBundle(LLVMContext::OB_convergencectrl);
+      if (Bundle)
+        checkBundle(*Bundle, CB, LiveTokens);
+
+      if (CB->getType()->isTokenTy())
+        LiveTokens.push_back(CB);
+    }
+
+    // Propagate token liveness
+    for (BasicBlock *Succ : successors(BB)) {
+      DomTreeNode *SuccNode = DT.getNode(Succ);
+      LTIt = LiveTokenMap.find(Succ);
+      if (LTIt == LiveTokenMap.end()) {
+        // We're the first predecessor: all tokens which dominate the
+        // successor are live for now.
+        LTIt = LiveTokenMap.try_emplace(Succ).first;
+        for (CallBase *LiveToken : LiveTokens) {
+          if (!DT.dominates(DT.getNode(LiveToken->getParent()), SuccNode))
+            break;
+          LTIt->second.push_back(LiveToken);
+        }
+      } else {
+        // Compute the intersection of live tokens.
+        auto It = llvm::partition(LTIt->second, [&LiveTokens](CallBase *Token) {
+          return llvm::is_contained(LiveTokens, Token);
+        });
+        LTIt->second.erase(It, LTIt->second.end());
+      }
+    }
+  }
+}
+
 // visitFunction - Verify that a function is ok.
 //
 void Verifier::visitFunction(const Function &F) {
@@ -3555,7 +3688,22 @@ void Verifier::visitCallBase(CallBase &Call) {
   if (Call.isInlineAsm())
     verifyInlineAsmCall(Call);
 
-  CV.visit(Call);
+  if (isControlledConvergent(Call)) {
+    Check(Call.isConvergent(),
+          "Expected convergent attribute on a controlled convergent call.",
+          Call);
+    Check(ConvergenceKind != UncontrolledConvergence,
+          "Cannot mix controlled and uncontrolled convergence in the same "
+          "function.",
+          Call);
+    ConvergenceKind = ControlledConvergence;
+  } else if (Call.isConvergent()) {
+    Check(ConvergenceKind != ControlledConvergence,
+          "Cannot mix controlled and uncontrolled convergence in the same "
+          "function.",
+          Call);
+    ConvergenceKind = UncontrolledConvergence;
+  }
 
   visitInstruction(Call);
 }


        


More information about the llvm-commits mailing list