[llvm] 93a3706 - [LLVM] move verification of convergence control to a class template

Sameer Sahasrabuddhe via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 22:52:39 PDT 2023


Author: Sameer Sahasrabuddhe
Date: 2023-08-01T11:21:48+05:30
New Revision: 93a3706711fd46d4d487640d91b16c2eec747c9e

URL: https://github.com/llvm/llvm-project/commit/93a3706711fd46d4d487640d91b16c2eec747c9e
DIFF: https://github.com/llvm/llvm-project/commit/93a3706711fd46d4d487640d91b16c2eec747c9e.diff

LOG: [LLVM] move verification of convergence control to a class template

The refactored template can now be used with MachineVerifier.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D156522

Added: 
    llvm/include/llvm/ADT/GenericConvergenceVerifier.h
    llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h
    llvm/include/llvm/IR/ConvergenceVerifier.h
    llvm/lib/IR/ConvergenceVerifier.cpp

Modified: 
    llvm/include/llvm/ADT/GenericSSAContext.h
    llvm/lib/CodeGen/MachineSSAContext.cpp
    llvm/lib/IR/CMakeLists.txt
    llvm/lib/IR/SSAContext.cpp
    llvm/lib/IR/Verifier.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/GenericConvergenceVerifier.h b/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
new file mode 100644
index 00000000000000..213d77593c9968
--- /dev/null
+++ b/llvm/include/llvm/ADT/GenericConvergenceVerifier.h
@@ -0,0 +1,100 @@
+//===- GenericConvergenceVerifier.h ---------------------------*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+///
+/// A verifier for the static rules of convergence control tokens that works
+/// with both LLVM IR and MIR.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_GENERICCONVERGENCEVERIFIER_H
+#define LLVM_ADT_GENERICCONVERGENCEVERIFIER_H
+
+#include "llvm/ADT/GenericCycleInfo.h"
+#include <functional>
+#include <utility>
+
+namespace llvm {
+
+/// Checks the static rules of convergence control, one function at a time.
+///
+/// Call initialize() for a function and visit() on each of its instructions;
+/// if sawTokens() then returns true, call verify() with the function's
+/// dominator tree. The generic member definitions are in
+/// GenericConvergenceVerifierImpl.h.
+template <typename ContextT> class GenericConvergenceVerifier {
+public:
+  using BlockT = typename ContextT::BlockT;
+  using FunctionT = typename ContextT::FunctionT;
+  using ValueRefT = typename ContextT::ValueRefT;
+  using InstructionT = typename ContextT::InstructionT;
+  using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using CycleInfoT = GenericCycleInfo<ContextT>;
+  using CycleT = typename CycleInfoT::CycleT;
+
+  /// Reset all state and prepare to verify \p F.
+  ///
+  /// \p FailureCB receives the message of every failed check. It is stored by
+  /// value, so passing a temporary callable is safe. If \p OS is non-null,
+  /// the values related to each failure are also printed to it, and it must
+  /// then outlive all later calls to visit() and verify().
+  void initialize(raw_ostream *OS,
+                  std::function<void(const Twine &Message)> FailureCB,
+                  const FunctionT &F) {
+    clear();
+    this->OS = OS;
+    this->FailureCB = std::move(FailureCB);
+    Context = ContextT(&F);
+  }
+
+  /// Overload for callers that always have a diagnostic stream.
+  void initialize(raw_ostream &OS,
+                  std::function<void(const Twine &Message)> FailureCB,
+                  const FunctionT &F) {
+    initialize(&OS, std::move(FailureCB), F);
+  }
+
+  /// Forget all recorded tokens, the cycle info and the convergence kind.
+  void clear();
+  /// Classify \p I and check the rules that apply to it in isolation.
+  void visit(const InstructionT &I);
+  /// Check the function-wide rules on token liveness and cycles.
+  void verify(const DominatorTreeT &DT);
+
+  /// True once visit() has classified the function as controlled convergent.
+  bool sawTokens() const { return ConvergenceKind == ControlledConvergence; }
+
+private:
+  /// Optional stream for printing values related to a failure.
+  raw_ostream *OS = nullptr;
+  std::function<void(const Twine &Message)> FailureCB;
+  CycleInfoT CI;
+  ContextT Context;
+
+  /// How the current function uses convergence, as determined by visit().
+  enum {
+    ControlledConvergence,
+    UncontrolledConvergence,
+    NoConvergence
+  } ConvergenceKind = NoConvergence;
+
+  // Cache token uses found so far. Note that we track the unique definitions
+  // and not the token values.
+  DenseMap<const InstructionT *, const InstructionT *> Tokens;
+
+  const InstructionT *findAndCheckConvergenceTokenUsed(const InstructionT &I);
+  bool isControlledConvergent(const InstructionT &I);
+  bool isConvergent(const InstructionT &I) const;
+
+  void reportFailure(const Twine &Message, ArrayRef<Printable> Values);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_ADT_GENERICCONVERGENCEVERIFIER_H

diff  --git a/llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h b/llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h
new file mode 100644
index 00000000000000..5a50ddf256ca69
--- /dev/null
+++ b/llvm/include/llvm/ADT/GenericConvergenceVerifierImpl.h
@@ -0,0 +1,223 @@
+//===- GenericConvergenceVerifierImpl.h -----------------------*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+///
+/// A verifier for the static rules of convergence control tokens that works
+/// with both LLVM IR and MIR.
+///
+/// This template implementation resides in a separate file so that it does not
+/// get injected into every .cpp file that includes the generic header.
+///
+/// DO NOT INCLUDE THIS FILE WHEN MERELY USING THE CONVERGENCE VERIFIER.
+///
+/// This file should only be included by files that implement a
+/// specialization of the relevant templates. Currently this is:
+/// - llvm/lib/IR/ConvergenceVerifier.cpp
+///
+/// The Check and CheckOrNull macros defined below stay visible to the
+/// including file, which may use them in its own specializations.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_GENERICCONVERGENCEVERIFIERIMPL_H
+#define LLVM_ADT_GENERICCONVERGENCEVERIFIERIMPL_H
+
+#include "llvm/ADT/GenericConvergenceVerifier.h"
+#include "llvm/ADT/GenericCycleInfo.h"
+#include "llvm/ADT/GenericSSAContext.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Intrinsics.h"
+
+// If C does not hold, report a failure using the remaining arguments (message
+// and values to print) and return from the enclosing function. Check is for
+// functions returning void; CheckOrNull returns a value-initialized result.
+#define Check(C, ...)                                                          \
+  do {                                                                         \
+    if (!(C)) {                                                                \
+      reportFailure(__VA_ARGS__);                                              \
+      return;                                                                  \
+    }                                                                          \
+  } while (false)
+
+#define CheckOrNull(C, ...)                                                    \
+  do {                                                                         \
+    if (!(C)) {                                                                \
+      reportFailure(__VA_ARGS__);                                              \
+      return {};                                                               \
+    }                                                                          \
+  } while (false)
+
+/// Returns true iff \p IntrinsicID is one of the intrinsics that define a
+/// convergence control token (convergence.anchor, .entry or .loop).
+static bool isConvergenceControlIntrinsic(unsigned IntrinsicID) {
+  switch (IntrinsicID) {
+  default:
+    return false;
+  case llvm::Intrinsic::experimental_convergence_anchor:
+  case llvm::Intrinsic::experimental_convergence_entry:
+  case llvm::Intrinsic::experimental_convergence_loop:
+    return true;
+  }
+}
+
+namespace llvm {
+/// Reset all per-function state: recorded token uses, the cycle info and the
+/// convergence kind.
+template <class ContextT> void GenericConvergenceVerifier<ContextT>::clear() {
+  Tokens.clear();
+  CI.clear();
+  ConvergenceKind = NoConvergence;
+}
+
+/// Classify \p I as controlled convergent, uncontrolled convergent or neither.
+/// Controlled convergent instructions must themselves be convergent, and the
+/// two kinds must not be mixed within one function.
+template <class ContextT>
+void GenericConvergenceVerifier<ContextT>::visit(const InstructionT &I) {
+  if (isControlledConvergent(I)) {
+    Check(isConvergent(I),
+          "Expected convergent attribute on a controlled convergent call.",
+          {Context.print(&I)});
+    Check(ConvergenceKind != UncontrolledConvergence,
+          "Cannot mix controlled and uncontrolled convergence in the same "
+          "function.",
+          {Context.print(&I)});
+    ConvergenceKind = ControlledConvergence;
+  } else if (isConvergent(I)) {
+    Check(ConvergenceKind != ControlledConvergence,
+          "Cannot mix controlled and uncontrolled convergence in the same "
+          "function.",
+          {Context.print(&I)});
+    ConvergenceKind = UncontrolledConvergence;
+  }
+}
+
+/// Pass \p Message to the failure callback, then, if a diagnostic stream was
+/// provided, print each of \p DumpedValues on its own line.
+template <class ContextT>
+void GenericConvergenceVerifier<ContextT>::reportFailure(
+    const Twine &Message, ArrayRef<Printable> DumpedValues) {
+  FailureCB(Message);
+  if (!OS)
+    return;
+  for (const Printable &V : DumpedValues)
+    *OS << V << '\n';
+}
+
+/// Check the function-wide rules for convergence control tokens. Token
+/// liveness is propagated in reverse post-order and every use must refer to a
+/// live token (regions are well-nested). A token used in a cycle that does not
+/// contain its definition may only be used by convergence.loop in the header
+/// of the outermost such cycle; that cycle must be reducible and may contain
+/// only one such use.
+template <class ContextT>
+void GenericConvergenceVerifier<ContextT>::verify(const DominatorTreeT &DT) {
+  assert(Context.getFunction());
+  const auto &F = *Context.getFunction();
+
+  DenseMap<const BlockT *, SmallVector<const InstructionT *, 8>> LiveTokenMap;
+  DenseMap<const CycleT *, const InstructionT *> CycleHearts;
+
+  // Just like the DominatorTree, compute the CycleInfo locally so that we
+  // can run the verifier outside of a pass manager and we don't rely on
+  // potentially out-dated analysis results.
+  CI.compute(const_cast<FunctionT &>(F));
+
+  auto checkToken = [&](const InstructionT *Token, const InstructionT *User,
+                        SmallVectorImpl<const InstructionT *> &LiveTokens) {
+    Check(llvm::is_contained(LiveTokens, Token),
+          "Convergence region is not well-nested.",
+          {Context.print(Token), Context.print(User)});
+    while (LiveTokens.back() != Token)
+      LiveTokens.pop_back();
+
+    // Check static rules about cycles.
+    auto *BB = User->getParent();
+    auto *BBCycle = CI.getCycle(BB);
+    if (!BBCycle)
+      return;
+
+    auto *DefBB = Token->getParent();
+    if (DefBB == BB || BBCycle->contains(DefBB)) {
+      // degenerate occurrence of a loop intrinsic
+      return;
+    }
+
+    Check(ContextT::getIntrinsicID(*User) ==
+              Intrinsic::experimental_convergence_loop,
+          "Convergence token used by an instruction other than "
+          "llvm.experimental.convergence.loop in a cycle that does "
+          "not contain the token's definition.",
+          {Context.print(User), CI.print(BBCycle)});
+
+    // Find the outermost cycle that contains the use but not the definition.
+    while (true) {
+      auto *Parent = BBCycle->getParentCycle();
+      if (!Parent || Parent->contains(DefBB))
+        break;
+      BBCycle = Parent;
+    }
+
+    Check(BBCycle->isReducible() && BB == BBCycle->getHeader(),
+          "Cycle heart must dominate all blocks in the cycle.",
+          {Context.print(User), Context.printAsOperand(BB), CI.print(BBCycle)});
+    Check(!CycleHearts.count(BBCycle),
+          "Two static convergence token uses in a cycle that does "
+          "not contain either token's definition.",
+          {Context.print(User), Context.print(CycleHearts[BBCycle]),
+           CI.print(BBCycle)});
+    CycleHearts[BBCycle] = User;
+  };
+
+  ReversePostOrderTraversal<const FunctionT *> RPOT(&F);
+  SmallVector<const InstructionT *, 8> LiveTokens;
+  for (auto *BB : RPOT) {
+    LiveTokens.clear();
+    auto LTIt = LiveTokenMap.find(BB);
+    if (LTIt != LiveTokenMap.end()) {
+      LiveTokens = std::move(LTIt->second);
+      LiveTokenMap.erase(LTIt);
+    }
+
+    for (auto &I : *BB) {
+      if (auto *Token = Tokens.lookup(&I))
+        checkToken(Token, &I, LiveTokens);
+      if (isConvergenceControlIntrinsic(ContextT::getIntrinsicID(I)))
+        LiveTokens.push_back(&I);
+    }
+
+    // Propagate token liveness
+    for (auto *Succ : successors(BB)) {
+      auto *SuccNode = DT.getNode(Succ);
+      LTIt = LiveTokenMap.find(Succ);
+      if (LTIt == LiveTokenMap.end()) {
+        // We're the first predecessor: all tokens which dominate the
+        // successor are live for now.
+        LTIt = LiveTokenMap.try_emplace(Succ).first;
+        for (auto LiveToken : LiveTokens) {
+          if (!DT.dominates(DT.getNode(LiveToken->getParent()), SuccNode))
+            break;
+          LTIt->second.push_back(LiveToken);
+        }
+      } else {
+        // Compute the intersection of live tokens.
+        auto It = llvm::partition(
+            LTIt->second, [&LiveTokens](const InstructionT *Token) {
+              return llvm::is_contained(LiveTokens, Token);
+            });
+        LTIt->second.erase(It, LTIt->second.end());
+      }
+    }
+  }
+}
+
+} // end namespace llvm
+
+#endif // LLVM_ADT_GENERICCONVERGENCEVERIFIERIMPL_H

diff  --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 45f5dc7d3fcd1d..6aa3a8b9b6e0b6 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -24,6 +24,10 @@ namespace llvm {
 template <typename, bool> class DominatorTreeBase;
 template <typename> class SmallVectorImpl;
 
+namespace Intrinsic {
+using ID = unsigned; // Must agree with Intrinsic::ID in llvm/IR/Intrinsics.h.
+} // namespace Intrinsic
+
 // Specializations of this template should provide the types used by the
 // template GenericSSAContext below.
 template <typename _FunctionT> struct GenericSSATraits;
@@ -78,6 +82,8 @@ template <typename _FunctionT> class GenericSSAContext {
 
   const FunctionT *getFunction() const { return F; }
 
+  static Intrinsic::ID getIntrinsicID(const InstructionT &I);
+
   static void appendBlockDefs(SmallVectorImpl<ValueRefT> &defs, BlockT &block);
   static void appendBlockDefs(SmallVectorImpl<ConstValueRefT> &defs,
                               const BlockT &block);
@@ -91,6 +97,7 @@ template <typename _FunctionT> class GenericSSAContext {
   const BlockT *getDefBlock(ConstValueRefT value) const;
 
   Printable print(const BlockT *block) const;
+  Printable printAsOperand(const BlockT *BB) const;
   Printable print(const InstructionT *inst) const;
   Printable print(ConstValueRefT value) const;
 };

diff  --git a/llvm/include/llvm/IR/ConvergenceVerifier.h b/llvm/include/llvm/IR/ConvergenceVerifier.h
new file mode 100644
index 00000000000000..df4c495f6a66fd
--- /dev/null
+++ b/llvm/include/llvm/IR/ConvergenceVerifier.h
@@ -0,0 +1,28 @@
+//===- ConvergenceVerifier.h - Verify convergencectrl -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file declares ConvergenceVerifier, the LLVM IR instantiation of the
+/// GenericConvergenceVerifier template.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_CONVERGENCEVERIFIER_H
+#define LLVM_IR_CONVERGENCEVERIFIER_H
+
+#include "llvm/ADT/GenericConvergenceVerifier.h"
+#include "llvm/IR/SSAContext.h"
+
+namespace llvm {
+// Explicitly instantiated in llvm/lib/IR/ConvergenceVerifier.cpp.
+extern template class GenericConvergenceVerifier<SSAContext>;
+using ConvergenceVerifier = GenericConvergenceVerifier<SSAContext>;
+
+} // namespace llvm
+
+#endif // LLVM_IR_CONVERGENCEVERIFIER_H

diff  --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp
index 4311255ed76d04..e384187b6e8593 100644
--- a/llvm/lib/CodeGen/MachineSSAContext.cpp
+++ b/llvm/lib/CodeGen/MachineSSAContext.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/MachineSSAContext.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstr.h"
@@ -58,6 +59,13 @@ bool MachineSSAContext::isConstantOrUndefValuePhi(const MachineInstr &Phi) {
   return Phi.isConstantValuePHI();
 }
 
+template <>
+Intrinsic::ID MachineSSAContext::getIntrinsicID(const MachineInstr &MI) {
+  // Anything that is not a GIntrinsic reports not_intrinsic.
+  const auto *Intr = dyn_cast<GIntrinsic>(&MI);
+  return Intr ? Intr->getIntrinsicID() : Intrinsic::not_intrinsic;
+}
+
 template <>
 Printable MachineSSAContext::print(const MachineBasicBlock *Block) const {
   if (!Block)
@@ -83,3 +91,8 @@ template <> Printable MachineSSAContext::print(Register Value) const {
     }
   });
 }
+
+template <>
+Printable MachineSSAContext::printAsOperand(const MachineBasicBlock *BB) const {
+  return Printable{[BB](raw_ostream &OS) { BB->printAsOperand(OS); }};
+}

diff  --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 217fe703dd4eef..d9656a24d0ed3f 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_llvm_component_library(LLVMCore
   ConstantFold.cpp
   ConstantRange.cpp
   Constants.cpp
+  ConvergenceVerifier.cpp
   Core.cpp
   CycleInfo.cpp
   DIBuilder.cpp

diff  --git a/llvm/lib/IR/ConvergenceVerifier.cpp b/llvm/lib/IR/ConvergenceVerifier.cpp
new file mode 100644
index 00000000000000..e1c15a3a840837
--- /dev/null
+++ b/llvm/lib/IR/ConvergenceVerifier.cpp
@@ -0,0 +1,86 @@
+//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// NOTE: Including the following header causes a premature instantiation of the
+// template, and the compiler complains about explicit specialization after
+// instantiation. So don't include it.
+//
+// #include "llvm/IR/ConvergenceVerifier.h"
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/GenericConvergenceVerifierImpl.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/SSAContext.h"
+
+using namespace llvm;
+
+/// If \p I is a call with a 'convergencectrl' operand bundle, validate the
+/// bundle, record the instruction defining its token in Tokens and return
+/// that instruction. Returns nullptr for non-calls and calls without the
+/// bundle, and also (after reporting a failure) if the bundle is invalid.
+template <>
+const Instruction *
+GenericConvergenceVerifier<SSAContext>::findAndCheckConvergenceTokenUsed(
+    const Instruction &I) {
+  auto *CB = dyn_cast<CallBase>(&I);
+  if (!CB)
+    return nullptr;
+
+  unsigned Count =
+      CB->countOperandBundlesOfType(LLVMContext::OB_convergencectrl);
+  // NOTE(review): 'convergencetrl' in this message is misspelled; existing
+  // diagnostic tests may match it verbatim, so confirm before fixing it.
+  CheckOrNull(Count <= 1,
+              "The 'convergencetrl' bundle can occur at most once on a call",
+              {Context.print(CB)});
+  if (!Count)
+    return nullptr;
+
+  auto Bundle = CB->getOperandBundle(LLVMContext::OB_convergencectrl);
+  CheckOrNull(Bundle->Inputs.size() == 1 &&
+                  Bundle->Inputs[0]->getType()->isTokenTy(),
+              "The 'convergencectrl' bundle requires exactly one token use.",
+              {Context.print(CB)});
+  auto *Token = Bundle->Inputs[0].get();
+  auto *Def = dyn_cast<Instruction>(Token);
+
+  CheckOrNull(
+      Def && isConvergenceControlIntrinsic(SSAContext::getIntrinsicID(*Def)),
+      "Convergence control tokens can only be produced by calls to the "
+      "convergence control intrinsics.",
+      {Context.print(Token), Context.print(&I)});
+
+  // CheckOrNull above has already returned unless Def is non-null.
+  Tokens[&I] = Def;
+  return Def;
+}
+
+/// Returns true iff \p I is a call that is marked convergent.
+template <>
+bool GenericConvergenceVerifier<SSAContext>::isConvergent(
+    const InstructionT &I) const {
+  const auto *CB = dyn_cast<CallBase>(&I);
+  return CB && CB->isConvergent();
+}
+
+/// Returns true iff \p I uses a convergence control token (recorded in Tokens
+/// as a side effect) or is itself a convergence control intrinsic.
+template <>
+bool GenericConvergenceVerifier<SSAContext>::isControlledConvergent(
+    const InstructionT &I) {
+  // First find a token and place it in the map.
+  if (findAndCheckConvergenceTokenUsed(I))
+    return true;
+
+  // The entry and anchor intrinsics do not use a token, so we do a broad check
+  // here. The loop intrinsic will be checked separately for a missing token.
+  return isConvergenceControlIntrinsic(SSAContext::getIntrinsicID(I));
+}
+
+template class llvm::GenericConvergenceVerifier<SSAContext>;

diff  --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp
index 3a5c4bf4aa30c0..220abe3083ebd7 100644
--- a/llvm/lib/IR/SSAContext.cpp
+++ b/llvm/lib/IR/SSAContext.cpp
@@ -16,10 +16,10 @@
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/ModuleSlotTracker.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 
@@ -69,6 +69,12 @@ bool SSAContext::isConstantOrUndefValuePhi(const Instruction &Instr) {
   return false;
 }
 
+template <> Intrinsic::ID SSAContext::getIntrinsicID(const Instruction &I) {
+  // Only calls can report an intrinsic ID; everything else is not_intrinsic.
+  const auto *Call = dyn_cast<CallBase>(&I);
+  return Call ? Call->getIntrinsicID() : Intrinsic::not_intrinsic;
+}
+
 template <> Printable SSAContext::print(const Value *V) const {
   return Printable([V](raw_ostream &Out) { V->print(Out); });
 }
@@ -89,3 +95,7 @@ template <> Printable SSAContext::print(const BasicBlock *BB) const {
     Out << MST.getLocalSlot(BB);
   });
 }
+
+template <> Printable SSAContext::printAsOperand(const BasicBlock *BB) const {
+  return Printable{[BB](raw_ostream &OS) { BB->printAsOperand(OS); }};
+}

diff  --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 21acb4185f7a8e..e03e6a117a243d 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -73,7 +73,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
-#include "llvm/IR/CycleInfo.h"
+#include "llvm/IR/ConvergenceVerifier.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -329,13 +329,6 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   /// The current source language.
   dwarf::SourceLanguage CurrentSourceLang = dwarf::DW_LANG_lo_user;
 
-  /// Whether the current function has convergencectrl operand bundles.
-  enum {
-    ControlledConvergence,
-    UncontrolledConvergence,
-    NoConvergence
-  } ConvergenceKind = NoConvergence;
-
   /// Whether source was present on the first DIFile encountered in each CU.
   DenseMap<const DICompileUnit *, bool> HasSourceDebugInfo;
 
@@ -370,6 +363,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   SmallVector<const DILocalVariable *, 16> DebugFnArgs;
 
   TBAAVerifier TBAAVerifyHelper;
+  ConvergenceVerifier CV;
 
   SmallVector<IntrinsicInst *, 4> NoAliasScopeDecls;
 
@@ -411,12 +405,19 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
       return false;
     }
 
+    auto FailureCB = [this](const Twine &Message) {
+      this->CheckFailed(Message);
+    };
+    CV.initialize(*OS, FailureCB, F);
+
     Broken = false;
     // FIXME: We strip const here because the inst visitor strips const.
     visit(const_cast<Function &>(F));
     verifySiblingFuncletUnwinds();
-    if (ConvergenceKind == ControlledConvergence)
-      verifyConvergenceControl(const_cast<Function &>(F));
+
+    if (CV.sawTokens())
+      CV.verify(DT);
+
     InstsInThisBlock.clear();
     DebugFnArgs.clear();
     LandingPadResultTy = nullptr;
@@ -424,7 +425,6 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
     SiblingFuncletInfo.clear();
     verifyNoAliasScopeDecl();
     NoAliasScopeDecls.clear();
-    ConvergenceKind = NoConvergence;
 
     return !Broken;
   }
@@ -600,7 +600,6 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   void verifyStatepoint(const CallBase &Call);
   void verifyFrameRecoverIndices();
   void verifySiblingFuncletUnwinds();
-  void verifyConvergenceControl(Function &F);
 
   void verifyFragmentExpression(const DbgVariableIntrinsic &I);
   template <typename ValueOrMetadata>
@@ -2535,138 +2534,6 @@ void Verifier::verifySiblingFuncletUnwinds() {
   }
 }
 
-static bool isConvergenceControlIntrinsic(const CallBase &Call) {
-  switch (Call.getIntrinsicID()) {
-  case Intrinsic::experimental_convergence_anchor:
-  case Intrinsic::experimental_convergence_entry:
-  case Intrinsic::experimental_convergence_loop:
-    return true;
-  default:
-    return false;
-  }
-}
-
-static bool isControlledConvergent(const CallBase &Call) {
-  if (Call.countOperandBundlesOfType(LLVMContext::OB_convergencectrl))
-    return true;
-  return isConvergenceControlIntrinsic(Call);
-}
-
-void Verifier::verifyConvergenceControl(Function &F) {
-  DenseMap<BasicBlock *, SmallVector<CallBase *, 8>> LiveTokenMap;
-  DenseMap<const Cycle *, const CallBase *> CycleHearts;
-
-  // Just like the DominatorTree, compute the CycleInfo locally so that we
-  // can run the verifier outside of a pass manager and we don't rely on
-  // potentially out-dated analysis results.
-  CycleInfo CI;
-  CI.compute(F);
-
-  auto checkBundle = [&](OperandBundleUse &Bundle, CallBase *CB,
-                         SmallVectorImpl<CallBase *> &LiveTokens) {
-    Check(Bundle.Inputs.size() == 1 && Bundle.Inputs[0]->getType()->isTokenTy(),
-          "The 'convergencectrl' bundle requires exactly one token use.", CB);
-
-    Value *Token = Bundle.Inputs[0].get();
-    auto *Def = dyn_cast<CallBase>(Token);
-    Check(Def && isConvergenceControlIntrinsic(*Def),
-          "Convergence control tokens can only be produced by calls to the "
-          "convergence control intrinsics.",
-          Token, CB);
-
-    Check(llvm::is_contained(LiveTokens, Token),
-          "Convergence region is not well-nested.", Token, CB);
-
-    while (LiveTokens.back() != Token)
-      LiveTokens.pop_back();
-
-    // Check static rules about cycles.
-    auto *BB = CB->getParent();
-    auto *BBCycle = CI.getCycle(BB);
-    if (!BBCycle)
-      return;
-
-    BasicBlock *DefBB = Def->getParent();
-    if (DefBB == BB || BBCycle->contains(DefBB)) {
-      // degenerate occurrence of a loop intrinsic
-      return;
-    }
-
-    auto *II = dyn_cast<IntrinsicInst>(CB);
-    Check(II &&
-              II->getIntrinsicID() == Intrinsic::experimental_convergence_loop,
-          "Convergence token used by an instruction other than "
-          "llvm.experimental.convergence.loop in a cycle that does "
-          "not contain the token's definition.",
-          CB, CI.print(BBCycle));
-
-    while (true) {
-      auto *Parent = BBCycle->getParentCycle();
-      if (!Parent || Parent->contains(DefBB))
-        break;
-      BBCycle = Parent;
-    };
-
-    Check(BBCycle->isReducible() && BB == BBCycle->getHeader(),
-          "Cycle heart must dominate all blocks in the cycle.", CB, BB,
-          CI.print(BBCycle));
-    Check(!CycleHearts.count(BBCycle),
-          "Two static convergence token uses in a cycle that does "
-          "not contain either token's definition.",
-          CB, CycleHearts[BBCycle], CI.print(BBCycle));
-    CycleHearts[BBCycle] = CB;
-  };
-
-  ReversePostOrderTraversal<Function *> RPOT(&F);
-  SmallVector<CallBase *, 8> LiveTokens;
-  for (BasicBlock *BB : RPOT) {
-    LiveTokens.clear();
-    auto LTIt = LiveTokenMap.find(BB);
-    if (LTIt != LiveTokenMap.end()) {
-      LiveTokens = std::move(LTIt->second);
-      LiveTokenMap.erase(LTIt);
-    }
-
-    for (Instruction &I : *BB) {
-      CallBase *CB = dyn_cast<CallBase>(&I);
-      if (!CB)
-        continue;
-
-      Check(CB->countOperandBundlesOfType(LLVMContext::OB_convergencectrl) <= 1,
-            "The 'convergencetrl' bundle can occur at most once on a call", CB);
-
-      auto Bundle = CB->getOperandBundle(LLVMContext::OB_convergencectrl);
-      if (Bundle)
-        checkBundle(*Bundle, CB, LiveTokens);
-
-      if (CB->getType()->isTokenTy())
-        LiveTokens.push_back(CB);
-    }
-
-    // Propagate token liveness
-    for (BasicBlock *Succ : successors(BB)) {
-      DomTreeNode *SuccNode = DT.getNode(Succ);
-      LTIt = LiveTokenMap.find(Succ);
-      if (LTIt == LiveTokenMap.end()) {
-        // We're the first predecessor: all tokens which dominate the
-        // successor are live for now.
-        LTIt = LiveTokenMap.try_emplace(Succ).first;
-        for (CallBase *LiveToken : LiveTokens) {
-          if (!DT.dominates(DT.getNode(LiveToken->getParent()), SuccNode))
-            break;
-          LTIt->second.push_back(LiveToken);
-        }
-      } else {
-        // Compute the intersection of live tokens.
-        auto It = llvm::partition(LTIt->second, [&LiveTokens](CallBase *Token) {
-          return llvm::is_contained(LiveTokens, Token);
-        });
-        LTIt->second.erase(It, LTIt->second.end());
-      }
-    }
-  }
-}
-
 // visitFunction - Verify that a function is ok.
 //
 void Verifier::visitFunction(const Function &F) {
@@ -3688,22 +3555,7 @@ void Verifier::visitCallBase(CallBase &Call) {
   if (Call.isInlineAsm())
     verifyInlineAsmCall(Call);
 
-  if (isControlledConvergent(Call)) {
-    Check(Call.isConvergent(),
-          "Expected convergent attribute on a controlled convergent call.",
-          Call);
-    Check(ConvergenceKind != UncontrolledConvergence,
-          "Cannot mix controlled and uncontrolled convergence in the same "
-          "function.",
-          Call);
-    ConvergenceKind = ControlledConvergence;
-  } else if (Call.isConvergent()) {
-    Check(ConvergenceKind != ControlledConvergence,
-          "Cannot mix controlled and uncontrolled convergence in the same "
-          "function.",
-          Call);
-    ConvergenceKind = UncontrolledConvergence;
-  }
+  CV.visit(Call);
 
   visitInstruction(Call);
 }


        


More information about the llvm-commits mailing list