[llvm] CFGPrinter: fix accidentally quadratic behavior (PR #125396)

Nicolai Hähnle via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 2 02:36:22 PST 2025


https://github.com/nhaehnle created https://github.com/llvm/llvm-project/pull/125396

Initialize a ModuleSlotTracker at most once per BasicBlock instead of once per Instruction. When the CFG info (DOTFuncInfo) is provided, the tracker is initialized only once per function.

>From 15e9fcfce820c0522f7af223f7ffdf1707ae7ee9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicolai=20H=C3=A4hnle?= <nicolai.haehnle at amd.com>
Date: Sun, 2 Feb 2025 10:05:06 +0100
Subject: [PATCH] CFGPrinter: fix accidentally quadratic behavior

Initialize a ModuleSlotTracker at most once per BasicBlock instead of once
per Instruction. When the CFG info (DOTFuncInfo) is provided, the tracker
is initialized only once per function.
---
 llvm/include/llvm/Analysis/CFGPrinter.h | 31 ++++++---------
 llvm/lib/Analysis/CFGPrinter.cpp        | 52 +++++++++++++++++++++++++
 2 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CFGPrinter.h b/llvm/include/llvm/Analysis/CFGPrinter.h
index cd785331d1f1468..b844e3f11c4a501 100644
--- a/llvm/include/llvm/Analysis/CFGPrinter.h
+++ b/llvm/include/llvm/Analysis/CFGPrinter.h
@@ -31,6 +31,8 @@
 #include "llvm/Support/FormatVariadic.h"
 
 namespace llvm {
+class ModuleSlotTracker;
+
 template <class GraphType> struct GraphTraits;
 class CFGViewerPass : public PassInfoMixin<CFGViewerPass> {
 public:
@@ -61,6 +63,7 @@ class DOTFuncInfo {
   const Function *F;
   const BlockFrequencyInfo *BFI;
   const BranchProbabilityInfo *BPI;
+  std::unique_ptr<ModuleSlotTracker> MSTStorage;
   uint64_t MaxFreq;
   bool ShowHeat;
   bool EdgeWeights;
@@ -68,14 +71,10 @@ class DOTFuncInfo {
 
 public:
   DOTFuncInfo(const Function *F) : DOTFuncInfo(F, nullptr, nullptr, 0) {}
+  ~DOTFuncInfo();
 
   DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI,
-              const BranchProbabilityInfo *BPI, uint64_t MaxFreq)
-      : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) {
-    ShowHeat = false;
-    EdgeWeights = !!BPI; // Print EdgeWeights when BPI is available.
-    RawWeights = !!BFI;  // Print RawWeights when BFI is available.
-  }
+              const BranchProbabilityInfo *BPI, uint64_t MaxFreq);
 
   const BlockFrequencyInfo *getBFI() const { return BFI; }
 
@@ -83,6 +82,8 @@ class DOTFuncInfo {
 
   const Function *getFunction() const { return this->F; }
 
+  ModuleSlotTracker *getModuleSlotTracker();
+
   uint64_t getMaxFreq() const { return MaxFreq; }
 
   uint64_t getFreq(const BasicBlock *BB) const {
@@ -203,22 +204,12 @@ struct DOTGraphTraits<DOTFuncInfo *> : public DefaultDOTGraphTraits {
     return SimpleNodeLabelString(Node);
   }
 
-  static void printBasicBlock(raw_string_ostream &OS, const BasicBlock &Node) {
-    // Prepend label name
-    Node.printAsOperand(OS, false);
-    OS << ":\n";
-    for (const Instruction &Inst : Node)
-      OS << Inst << "\n";
-  }
-
   static std::string getCompleteNodeLabel(
       const BasicBlock *Node, DOTFuncInfo *,
       function_ref<void(raw_string_ostream &, const BasicBlock &)>
-          HandleBasicBlock = printBasicBlock,
-      function_ref<void(std::string &, unsigned &, unsigned)>
-          HandleComment = eraseComment) {
-    return CompleteNodeLabelString(Node, HandleBasicBlock, HandleComment);
-  }
+          HandleBasicBlock = {},
+      function_ref<void(std::string &, unsigned &, unsigned)> HandleComment =
+          eraseComment);
 
   std::string getNodeLabel(const BasicBlock *Node, DOTFuncInfo *CFGInfo) {
 
@@ -337,6 +328,6 @@ struct DOTGraphTraits<DOTFuncInfo *> : public DefaultDOTGraphTraits {
   bool isNodeHidden(const BasicBlock *Node, const DOTFuncInfo *CFGInfo);
   void computeDeoptOrUnreachablePaths(const Function *F);
 };
-} // End llvm namespace
+} // namespace llvm
 
 #endif
diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp
index af18fb6626e3bf8..e6902f22f75949b 100644
--- a/llvm/lib/Analysis/CFGPrinter.cpp
+++ b/llvm/lib/Analysis/CFGPrinter.cpp
@@ -19,6 +19,7 @@
 
 #include "llvm/Analysis/CFGPrinter.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/IR/ModuleSlotTracker.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/GraphWriter.h"
@@ -90,6 +91,22 @@ static void viewCFG(Function &F, const BlockFrequencyInfo *BFI,
   ViewGraph(&CFGInfo, "cfg." + F.getName(), CFGOnly);
 }
 
+DOTFuncInfo::DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI,
+                         const BranchProbabilityInfo *BPI, uint64_t MaxFreq)
+    : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) {
+  ShowHeat = false;
+  EdgeWeights = !!BPI; // Print EdgeWeights when BPI is available.
+  RawWeights = !!BFI;  // Print RawWeights when BFI is available.
+}
+
+DOTFuncInfo::~DOTFuncInfo() = default;
+
+ModuleSlotTracker *DOTFuncInfo::getModuleSlotTracker() {
+  if (!MSTStorage)
+    MSTStorage = std::make_unique<ModuleSlotTracker>(F->getParent());
+  return &*MSTStorage;
+}
+
 PreservedAnalyses CFGViewerPass::run(Function &F, FunctionAnalysisManager &AM) {
   if (!CFGFuncName.empty() && !F.getName().contains(CFGFuncName))
     return PreservedAnalyses::all();
@@ -208,3 +225,38 @@ bool DOTGraphTraits<DOTFuncInfo *>::isNodeHidden(const BasicBlock *Node,
   }
   return false;
 }
+
+std::string DOTGraphTraits<DOTFuncInfo *>::getCompleteNodeLabel(
+    const BasicBlock *Node, DOTFuncInfo *CFGInfo,
+    function_ref<void(raw_string_ostream &, const BasicBlock &)>
+        HandleBasicBlock,
+    function_ref<void(std::string &, unsigned &, unsigned)> HandleComment) {
+  if (HandleBasicBlock)
+    return CompleteNodeLabelString(Node, HandleBasicBlock, HandleComment);
+
+  // Default basic block printing
+  std::optional<ModuleSlotTracker> MSTStorage;
+  ModuleSlotTracker *MST = nullptr;
+
+  if (CFGInfo) {
+    MST = CFGInfo->getModuleSlotTracker();
+  } else {
+    MSTStorage.emplace(Node->getModule());
+    MST = &*MSTStorage;
+  }
+
+  return CompleteNodeLabelString(
+      Node,
+      function_ref<void(raw_string_ostream &, const BasicBlock &)>(
+          [MST](raw_string_ostream &OS, const BasicBlock &Node) -> void {
+            // Prepend label name
+            Node.printAsOperand(OS, false);
+            OS << ":\n";
+
+            for (const Instruction &Inst : Node) {
+              Inst.print(OS, *MST, /* IsForDebug */ false);
+              OS << '\n';
+            }
+          }),
+      HandleComment);
+}



More information about the llvm-commits mailing list