[llvm] b14e30f - [LLVM] refactor GenericSSAContext and its specializations

Sameer Sahasrabuddhe via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 26 21:25:26 PDT 2023


Author: Sameer Sahasrabuddhe
Date: 2023-07-27T09:54:50+05:30
New Revision: b14e30f10dafb67101c5fbf29877226b7ea803e5

URL: https://github.com/llvm/llvm-project/commit/b14e30f10dafb67101c5fbf29877226b7ea803e5
DIFF: https://github.com/llvm/llvm-project/commit/b14e30f10dafb67101c5fbf29877226b7ea803e5.diff

LOG: [LLVM] refactor GenericSSAContext and its specializations

Fix the GenericSSAContext template so that it declares all the necessary
typenames, as well as the methods that its specializations SSAContext and
MachineSSAContext must implement. The per-IR types are now supplied by
specializations of a new GenericSSATraits template.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D156288

Added: 
    

Modified: 
    llvm/include/llvm/ADT/GenericCycleImpl.h
    llvm/include/llvm/ADT/GenericCycleInfo.h
    llvm/include/llvm/ADT/GenericSSAContext.h
    llvm/include/llvm/ADT/GenericUniformityImpl.h
    llvm/include/llvm/ADT/GenericUniformityInfo.h
    llvm/include/llvm/CodeGen/MachineSSAContext.h
    llvm/include/llvm/IR/SSAContext.h
    llvm/lib/Analysis/UniformityAnalysis.cpp
    llvm/lib/CodeGen/MachineSSAContext.cpp
    llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
    llvm/lib/IR/SSAContext.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/GenericCycleImpl.h b/llvm/include/llvm/ADT/GenericCycleImpl.h
index c9e0772c2464ea..11dda696cf25f8 100644
--- a/llvm/include/llvm/ADT/GenericCycleImpl.h
+++ b/llvm/include/llvm/ADT/GenericCycleImpl.h
@@ -354,11 +354,11 @@ template <typename ContextT> void GenericCycleInfo<ContextT>::clear() {
 template <typename ContextT>
 void GenericCycleInfo<ContextT>::compute(FunctionT &F) {
   GenericCycleInfoCompute<ContextT> Compute(*this);
-  Context.setFunction(F);
+  Context = ContextT(&F);
 
   LLVM_DEBUG(errs() << "Computing cycles for function: " << F.getName()
                     << "\n");
-  Compute.run(ContextT::getEntryBlock(F));
+  Compute.run(&F.front());
 
   assert(validateTree());
 }

diff  --git a/llvm/include/llvm/ADT/GenericCycleInfo.h b/llvm/include/llvm/ADT/GenericCycleInfo.h
index 51ea7ed9a49822..3489872f2e6f1b 100644
--- a/llvm/include/llvm/ADT/GenericCycleInfo.h
+++ b/llvm/include/llvm/ADT/GenericCycleInfo.h
@@ -256,7 +256,7 @@ template <typename ContextT> class GenericCycleInfo {
   void clear();
   void compute(FunctionT &F);
 
-  FunctionT *getFunction() const { return Context.getFunction(); }
+  const FunctionT *getFunction() const { return Context.getFunction(); }
   const ContextT &getSSAContext() const { return Context; }
 
   CycleT *getCycle(const BlockT *Block) const;

diff  --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 929fd1442750fb..45f5dc7d3fcd1d 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -21,87 +21,78 @@
 
 namespace llvm {
 
+template <typename, bool> class DominatorTreeBase;
+template <typename> class SmallVectorImpl;
+
+// Specializations of this template should provide the types used by the
+// template GenericSSAContext below.
+template <typename _FunctionT> struct GenericSSATraits;
+
+// Ideally this should have been a stateless traits class. But the print methods
+// for Machine IR need access to the owning function. So we track that state in
+// the template itself.
+//
+// We use FunctionT as a template argument and not GenericSSATraits to allow
+// forward declarations using well-known typenames.
 template <typename _FunctionT> class GenericSSAContext {
-public:
-  // Specializations should provide the following types that are similar to how
-  // LLVM IR is structured:
+  using SSATraits = GenericSSATraits<_FunctionT>;
+  const typename SSATraits::FunctionT *F;
 
+public:
   // The smallest unit of the IR is a ValueT. The SSA context uses a ValueRefT,
   // which is a pointer to a ValueT, since Machine IR does not have the
   // equivalent of a ValueT.
-  //
-  // using ValueRefT = ...
-  //
+  using ValueRefT = typename SSATraits::ValueRefT;
+
   // The ConstValueRefT is needed to work with "const Value *", where const
   // needs to bind to the pointee and not the pointer.
-  //
-  // using ConstValueRefT = ...
-  //
-  // The null value for ValueRefT.
-  //
-  // static constexpr ValueRefT ValueRefNull;
+  using ConstValueRefT = typename SSATraits::ConstValueRefT;
+
+  // The null value for ValueRefT. For LLVM IR and MIR, this is simply the
+  // default constructed value.
+  static constexpr ValueRefT *ValueRefNull = {};
 
   // An InstructionT usually defines one or more ValueT objects.
-  //
-  // using InstructionT = ... must be a subclass of Value
+  using InstructionT = typename SSATraits::InstructionT;
 
   // A UseT represents a data-edge from the defining instruction to the using
   // instruction.
-  //
-  // using UseT = ...
+  using UseT = typename SSATraits::UseT;
 
   // A BlockT is a sequence of InstructionT, and forms a node of the CFG. It
   // has global methods predecessors() and successors() that return
   // the list of incoming CFG edges and outgoing CFG edges
   // respectively.
-  //
-  // using BlockT = ...
+  using BlockT = typename SSATraits::BlockT;
 
   // A FunctionT represents a CFG along with arguments and return values. It is
   // the smallest complete unit of code in a Module.
-  //
-  // The compiler produces an error here if this class is implicitly
-  // specialized due to an instantiation. An explicit specialization
-  // of this template needs to be added before the instantiation point
-  // indicated by the compiler.
-  using FunctionT = typename _FunctionT::invalidTemplateInstanceError;
+  using FunctionT = typename SSATraits::FunctionT;
 
   // A dominator tree provides the dominance relation between basic blocks in
   // a given funciton.
-  //
-  // using DominatorTreeT = ...
-
-  // Initialize the SSA context with information about the FunctionT being
-  // processed.
-  //
-  // void setFunction(FunctionT &function);
-  // FunctionT* getFunction() const;
-
-  // Every FunctionT has a unique BlockT marked as its entry.
-  //
-  // static BlockT* getEntryBlock(FunctionT &F);
-
-  // Methods to examine basic blocks and values
-  //
-  // static void appendBlockDefs(SmallVectorImpl<ValueRefT> &defs,
-  //                             BlockT &block);
-  // static void appendBlockDefs(SmallVectorImpl<const ValueRefT> &defs,
-  //                             const BlockT &block);
-
-  // static void appendBlockTerms(SmallVectorImpl<InstructionT *> &terms,
-  //                              BlockT &block);
-  // static void appendBlockTerms(SmallVectorImpl<const InstructionT *> &terms,
-  //                              const BlockT &block);
-  //
-  // static bool comesBefore(const InstructionT *lhs, const InstructionT *rhs);
-  // static bool isConstantOrUndefValuePhi(const InstructionT &Instr);
-  // const BlockT *getDefBlock(const ValueRefT value) const;
-
-  // Methods to print various objects.
-  //
-  // Printable print(BlockT *block) const;
-  // Printable print(InstructionT *inst) const;
-  // Printable print(ValueRefT value) const;
+  using DominatorTreeT = DominatorTreeBase<BlockT, false>;
+
+  GenericSSAContext() = default;
+  GenericSSAContext(const FunctionT *F) : F(F) {}
+
+  const FunctionT *getFunction() const { return F; }
+
+  static void appendBlockDefs(SmallVectorImpl<ValueRefT> &defs, BlockT &block);
+  static void appendBlockDefs(SmallVectorImpl<ConstValueRefT> &defs,
+                              const BlockT &block);
+
+  static void appendBlockTerms(SmallVectorImpl<InstructionT *> &terms,
+                               BlockT &block);
+  static void appendBlockTerms(SmallVectorImpl<const InstructionT *> &terms,
+                               const BlockT &block);
+
+  static bool isConstantOrUndefValuePhi(const InstructionT &Instr);
+  const BlockT *getDefBlock(ConstValueRefT value) const;
+
+  Printable print(const BlockT *block) const;
+  Printable print(const InstructionT *inst) const;
+  Printable print(ConstValueRefT value) const;
 };
 } // namespace llvm
 

diff  --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index 4df04accc6835d..2d5f6ee037ef7b 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -129,11 +129,11 @@ template <typename ContextT> class ModifiedPostOrder {
   const ContextT &Context;
 
   void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
-                      SmallPtrSetImpl<BlockT *> &Finalized);
+                      SmallPtrSetImpl<const BlockT *> &Finalized);
 
-  void computeStackPO(SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI,
-                      const CycleT *Cycle,
-                      SmallPtrSetImpl<BlockT *> &Finalized);
+  void computeStackPO(SmallVectorImpl<const BlockT *> &Stack,
+                      const CycleInfoT &CI, const CycleT *Cycle,
+                      SmallPtrSetImpl<const BlockT *> &Finalized);
 };
 
 template <typename> class DivergencePropagator;
@@ -342,11 +342,10 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
       typename SyncDependenceAnalysisT::DivergenceDescriptor;
   using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;
 
-  GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT,
-                                const CycleInfoT &CI,
+  GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
                                 const TargetTransformInfo *TTI)
-      : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT),
-        SDA(Context, DT, CI) {}
+      : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
+        TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
 
   void initialize();
 
@@ -1135,10 +1134,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
 
 template <typename ContextT>
 GenericUniformityInfo<ContextT>::GenericUniformityInfo(
-    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
-    const TargetTransformInfo *TTI)
-    : F(&Func) {
-  DA.reset(new ImplT{Func, DT, CI, TTI});
+    const DominatorTreeT &DT, const CycleInfoT &CI,
+    const TargetTransformInfo *TTI) {
+  DA.reset(new ImplT{DT, CI, TTI});
 }
 
 template <typename ContextT>
@@ -1214,6 +1212,12 @@ bool GenericUniformityInfo<ContextT>::hasDivergence() const {
   return DA->hasDivergence();
 }
 
+template <typename ContextT>
+const typename ContextT::FunctionT &
+GenericUniformityInfo<ContextT>::getFunction() const {
+  return DA->getFunction();
+}
+
 /// Whether \p V is divergent at its definition.
 template <typename ContextT>
 bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
@@ -1243,8 +1247,8 @@ void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
 
 template <typename ContextT>
 void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
-    SmallVectorImpl<BlockT *> &Stack, const CycleInfoT &CI, const CycleT *Cycle,
-    SmallPtrSetImpl<BlockT *> &Finalized) {
+    SmallVectorImpl<const BlockT *> &Stack, const CycleInfoT &CI,
+    const CycleT *Cycle, SmallPtrSetImpl<const BlockT *> &Finalized) {
   LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
   while (!Stack.empty()) {
     auto *NextBB = Stack.back();
@@ -1313,9 +1317,9 @@ void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
 template <typename ContextT>
 void ModifiedPostOrder<ContextT>::computeCyclePO(
     const CycleInfoT &CI, const CycleT *Cycle,
-    SmallPtrSetImpl<BlockT *> &Finalized) {
+    SmallPtrSetImpl<const BlockT *> &Finalized) {
   LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
-  SmallVector<BlockT *> Stack;
+  SmallVector<const BlockT *> Stack;
   auto *CycleHeader = Cycle->getHeader();
 
   LLVM_DEBUG(dbgs() << "  noted header: "
@@ -1352,11 +1356,11 @@ void ModifiedPostOrder<ContextT>::computeCyclePO(
 /// \brief Generically compute the modified post order.
 template <typename ContextT>
 void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
-  SmallPtrSet<BlockT *, 32> Finalized;
-  SmallVector<BlockT *> Stack;
+  SmallPtrSet<const BlockT *, 32> Finalized;
+  SmallVector<const BlockT *> Stack;
   auto *F = CI.getFunction();
   Stack.reserve(24); // FIXME made-up number
-  Stack.push_back(GraphTraits<FunctionT *>::getEntryNode(F));
+  Stack.push_back(&F->front());
   computeStackPO(Stack, CI, nullptr, Finalized);
 }
 

diff  --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index 114fdfed765c21..e53afccc020b46 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -40,8 +40,7 @@ template <typename ContextT> class GenericUniformityInfo {
   using CycleInfoT = GenericCycleInfo<ContextT>;
   using CycleT = typename CycleInfoT::CycleT;
 
-  GenericUniformityInfo(FunctionT &F, const DominatorTreeT &DT,
-                        const CycleInfoT &CI,
+  GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
                         const TargetTransformInfo *TTI = nullptr);
   GenericUniformityInfo() = default;
   GenericUniformityInfo(GenericUniformityInfo &&) = default;
@@ -56,7 +55,7 @@ template <typename ContextT> class GenericUniformityInfo {
   bool hasDivergence() const;
 
   /// The GPU kernel this analysis result is for
-  const FunctionT &getFunction() const { return *F; }
+  const FunctionT &getFunction() const;
 
   /// Whether \p V is divergent at its definition.
   bool isDivergent(ConstValueRefT V) const;
@@ -82,7 +81,6 @@ template <typename ContextT> class GenericUniformityInfo {
 private:
   using ImplT = GenericUniformityAnalysisImpl<ContextT>;
 
-  FunctionT *F;
   std::unique_ptr<ImplT, GenericUniformityAnalysisImplDeleter<ImplT>> DA;
 
   GenericUniformityInfo(const GenericUniformityInfo &) = delete;

diff  --git a/llvm/include/llvm/CodeGen/MachineSSAContext.h b/llvm/include/llvm/CodeGen/MachineSSAContext.h
index 2409c83071e16c..f2f733023dfe38 100644
--- a/llvm/include/llvm/CodeGen/MachineSSAContext.h
+++ b/llvm/include/llvm/CodeGen/MachineSSAContext.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_CODEGEN_MACHINESSACONTEXT_H
 #define LLVM_CODEGEN_MACHINESSACONTEXT_H
 
+#include "llvm/ADT/GenericSSAContext.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/Support/Printable.h"
 
@@ -23,8 +24,6 @@ class MachineRegisterInfo;
 class MachineInstr;
 class MachineFunction;
 class Register;
-template <typename _FunctionT> class GenericSSAContext;
-template <typename, bool> class DominatorTreeBase;
 
 inline unsigned succ_size(const MachineBasicBlock *BB) {
   return BB->succ_size();
@@ -34,37 +33,13 @@ inline unsigned pred_size(const MachineBasicBlock *BB) {
 }
 inline auto instrs(const MachineBasicBlock &BB) { return BB.instrs(); }
 
-template <> class GenericSSAContext<MachineFunction> {
-  const MachineRegisterInfo *RegInfo = nullptr;
-  MachineFunction *MF = nullptr;
-
-public:
+template <> struct GenericSSATraits<MachineFunction> {
   using BlockT = MachineBasicBlock;
   using FunctionT = MachineFunction;
   using InstructionT = MachineInstr;
   using ValueRefT = Register;
   using ConstValueRefT = Register;
   using UseT = MachineOperand;
-  using DominatorTreeT = DominatorTreeBase<BlockT, false>;
-
-  static constexpr Register ValueRefNull = 0;
-
-  void setFunction(MachineFunction &Fn);
-  MachineFunction *getFunction() const { return MF; }
-
-  static MachineBasicBlock *getEntryBlock(MachineFunction &F);
-  static void appendBlockDefs(SmallVectorImpl<Register> &defs,
-                              const MachineBasicBlock &block);
-  static void appendBlockTerms(SmallVectorImpl<MachineInstr *> &terms,
-                               MachineBasicBlock &block);
-  static void appendBlockTerms(SmallVectorImpl<const MachineInstr *> &terms,
-                               const MachineBasicBlock &block);
-  MachineBasicBlock *getDefBlock(Register) const;
-  static bool isConstantOrUndefValuePhi(const MachineInstr &Phi);
-
-  Printable print(const MachineBasicBlock *Block) const;
-  Printable print(const MachineInstr *Inst) const;
-  Printable print(Register Value) const;
 };
 
 using MachineSSAContext = GenericSSAContext<MachineFunction>;

diff  --git a/llvm/include/llvm/IR/SSAContext.h b/llvm/include/llvm/IR/SSAContext.h
index 557ec752c21645..d0da5e222a8e63 100644
--- a/llvm/include/llvm/IR/SSAContext.h
+++ b/llvm/include/llvm/IR/SSAContext.h
@@ -17,60 +17,24 @@
 
 #include "llvm/ADT/GenericSSAContext.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/ModuleSlotTracker.h"
-#include "llvm/Support/Printable.h"
-
-#include <memory>
 
 namespace llvm {
 class BasicBlock;
 class Function;
 class Instruction;
 class Value;
-template <typename> class SmallVectorImpl;
-template <typename, bool> class DominatorTreeBase;
 
 inline auto instrs(const BasicBlock &BB) {
   return llvm::make_range(BB.begin(), BB.end());
 }
 
-template <> class GenericSSAContext<Function> {
-  Function *F;
-
-public:
+template <> struct GenericSSATraits<Function> {
   using BlockT = BasicBlock;
   using FunctionT = Function;
   using InstructionT = Instruction;
   using ValueRefT = Value *;
   using ConstValueRefT = const Value *;
   using UseT = Use;
-  using DominatorTreeT = DominatorTreeBase<BlockT, false>;
-
-  static constexpr Value *ValueRefNull = nullptr;
-
-  void setFunction(Function &Fn);
-  Function *getFunction() const { return F; }
-
-  static BasicBlock *getEntryBlock(Function &F);
-  static const BasicBlock *getEntryBlock(const Function &F);
-
-  static void appendBlockDefs(SmallVectorImpl<Value *> &defs,
-                              BasicBlock &block);
-  static void appendBlockDefs(SmallVectorImpl<const Value *> &defs,
-                              const BasicBlock &block);
-
-  static void appendBlockTerms(SmallVectorImpl<Instruction *> &terms,
-                               BasicBlock &block);
-  static void appendBlockTerms(SmallVectorImpl<const Instruction *> &terms,
-                               const BasicBlock &block);
-
-  static bool comesBefore(const Instruction *lhs, const Instruction *rhs);
-  static bool isConstantOrUndefValuePhi(const Instruction &Instr);
-  const BasicBlock *getDefBlock(const Value *value) const;
-
-  Printable print(const BasicBlock *Block) const;
-  Printable print(const Instruction *Inst) const;
-  Printable print(const Value *Value) const;
 };
 
 using SSAContext = GenericSSAContext<Function>;

diff  --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index bf0b194dcd708c..2d617db431c588 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -118,7 +118,7 @@ llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
   auto &CI = FAM.getResult<CycleAnalysis>(F);
-  UniformityInfo UI{F, DT, CI, &TTI};
+  UniformityInfo UI{DT, CI, &TTI};
   // Skip computation if we can assume everything is uniform.
   if (TTI.hasBranchDivergence(&F))
     UI.compute();
@@ -171,8 +171,7 @@ bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
   m_function = &F;
-  m_uniformityInfo =
-      UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
+  m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
 
   // Skip computation if we can assume everything is uniform.
   if (targetTransformInfo.hasBranchDivergence(m_function))

diff  --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp
index 324084fb9c3237..4311255ed76d04 100644
--- a/llvm/lib/CodeGen/MachineSSAContext.cpp
+++ b/llvm/lib/CodeGen/MachineSSAContext.cpp
@@ -21,15 +21,23 @@
 
 using namespace llvm;
 
-void MachineSSAContext::setFunction(MachineFunction &Fn) {
-  MF = &Fn;
-  RegInfo = &MF->getRegInfo();
+template <>
+void MachineSSAContext::appendBlockDefs(SmallVectorImpl<Register> &defs,
+                                        const MachineBasicBlock &block) {
+  for (auto &instr : block.instrs()) {
+    for (auto &op : instr.all_defs())
+      defs.push_back(op.getReg());
+  }
 }
 
-MachineBasicBlock *MachineSSAContext::getEntryBlock(MachineFunction &F) {
-  return &F.front();
+template <>
+void MachineSSAContext::appendBlockTerms(SmallVectorImpl<MachineInstr *> &terms,
+                                         MachineBasicBlock &block) {
+  for (auto &T : block.terminators())
+    terms.push_back(&T);
 }
 
+template <>
 void MachineSSAContext::appendBlockTerms(
     SmallVectorImpl<const MachineInstr *> &terms,
     const MachineBasicBlock &block) {
@@ -37,37 +45,32 @@ void MachineSSAContext::appendBlockTerms(
     terms.push_back(&T);
 }
 
-void MachineSSAContext::appendBlockDefs(SmallVectorImpl<Register> &defs,
-                                        const MachineBasicBlock &block) {
-  for (const MachineInstr &instr : block.instrs()) {
-    for (const MachineOperand &op : instr.all_defs())
-      defs.push_back(op.getReg());
-  }
-}
-
 /// Get the defining block of a value.
-MachineBasicBlock *MachineSSAContext::getDefBlock(Register value) const {
+template <>
+const MachineBasicBlock *MachineSSAContext::getDefBlock(Register value) const {
   if (!value)
     return nullptr;
-  return RegInfo->getVRegDef(value)->getParent();
+  return F->getRegInfo().getVRegDef(value)->getParent();
 }
 
+template <>
 bool MachineSSAContext::isConstantOrUndefValuePhi(const MachineInstr &Phi) {
   return Phi.isConstantValuePHI();
 }
 
+template <>
 Printable MachineSSAContext::print(const MachineBasicBlock *Block) const {
   if (!Block)
     return Printable([](raw_ostream &Out) { Out << "<nullptr>"; });
   return Printable([Block](raw_ostream &Out) { Block->printName(Out); });
 }
 
-Printable MachineSSAContext::print(const MachineInstr *I) const {
+template <> Printable MachineSSAContext::print(const MachineInstr *I) const {
   return Printable([I](raw_ostream &Out) { I->print(Out); });
 }
 
-Printable MachineSSAContext::print(Register Value) const {
-  auto *MRI = RegInfo;
+template <> Printable MachineSSAContext::print(Register Value) const {
+  auto *MRI = &F->getRegInfo();
   return Printable([MRI, Value](raw_ostream &Out) {
     Out << printReg(Value, MRI->getTargetRegisterInfo(), 0, MRI);
 

diff  --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index 0e02c50284c60f..3e0fe2b1ba087f 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -157,7 +157,7 @@ MachineUniformityInfo llvm::computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
     const MachineDomTree &domTree, bool HasBranchDivergence) {
   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
-  MachineUniformityInfo UI(F, domTree, cycleInfo);
+  MachineUniformityInfo UI(domTree, cycleInfo);
   if (HasBranchDivergence)
     UI.compute();
   return UI;

diff  --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp
index 4790d19b74b51a..3a5c4bf4aa30c0 100644
--- a/llvm/lib/IR/SSAContext.cpp
+++ b/llvm/lib/IR/SSAContext.cpp
@@ -19,31 +19,21 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/IR/ModuleSlotTracker.h"
 
 using namespace llvm;
 
-void SSAContext::setFunction(Function &Fn) { F = &Fn; }
-
-BasicBlock *SSAContext::getEntryBlock(Function &F) {
-  return &F.getEntryBlock();
-}
-
-const BasicBlock *SSAContext::getEntryBlock(const Function &F) {
-  return &F.getEntryBlock();
-}
-
+template <>
 void SSAContext::appendBlockDefs(SmallVectorImpl<Value *> &defs,
                                  BasicBlock &block) {
-  for (auto &instr : block.instructionsWithoutDebug(/*SkipPseudoOp=*/true)) {
+  for (auto &instr : block) {
     if (instr.isTerminator())
       break;
-    if (instr.getType()->isVoidTy())
-      continue;
-    auto *def = &instr;
-    defs.push_back(def);
+    defs.push_back(&instr);
   }
 }
 
+template <>
 void SSAContext::appendBlockDefs(SmallVectorImpl<const Value *> &defs,
                                  const BasicBlock &block) {
   for (auto &instr : block) {
@@ -53,41 +43,41 @@ void SSAContext::appendBlockDefs(SmallVectorImpl<const Value *> &defs,
   }
 }
 
+template <>
 void SSAContext::appendBlockTerms(SmallVectorImpl<Instruction *> &terms,
                                   BasicBlock &block) {
   terms.push_back(block.getTerminator());
 }
 
+template <>
 void SSAContext::appendBlockTerms(SmallVectorImpl<const Instruction *> &terms,
                                   const BasicBlock &block) {
   terms.push_back(block.getTerminator());
 }
 
+template <>
 const BasicBlock *SSAContext::getDefBlock(const Value *value) const {
   if (const auto *instruction = dyn_cast<Instruction>(value))
     return instruction->getParent();
   return nullptr;
 }
 
-bool SSAContext::comesBefore(const Instruction *lhs, const Instruction *rhs) {
-  return lhs->comesBefore(rhs);
-}
-
+template <>
 bool SSAContext::isConstantOrUndefValuePhi(const Instruction &Instr) {
   if (auto *Phi = dyn_cast<PHINode>(&Instr))
     return Phi->hasConstantOrUndefValue();
   return false;
 }
 
-Printable SSAContext::print(const Value *V) const {
+template <> Printable SSAContext::print(const Value *V) const {
   return Printable([V](raw_ostream &Out) { V->print(Out); });
 }
 
-Printable SSAContext::print(const Instruction *Inst) const {
+template <> Printable SSAContext::print(const Instruction *Inst) const {
   return print(cast<Value>(Inst));
 }
 
-Printable SSAContext::print(const BasicBlock *BB) const {
+template <> Printable SSAContext::print(const BasicBlock *BB) const {
   if (!BB)
     return Printable([](raw_ostream &Out) { Out << "<nullptr>"; });
   if (BB->hasName())


        


More information about the llvm-commits mailing list