[llvm] [NFC][LLVM][CodeGen] Refactor MIR Printer (PR #137361)

Rahul Joshi via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 27 12:29:46 PDT 2025


https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/137361

>From 39870c9da9bd07f4f3230e8219e66f78851bd4ad Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Fri, 25 Apr 2025 09:33:42 -0700
Subject: [PATCH] [NFC][LLVM][MI] Refactor MI Printer code.

- Move `MIPrinter` class to anonymous namespace, and remove it as a
  friend of `MachineBasicBlock`.
- Move `canPredictBranchProbabilities` to `MachineBasicBlock` and
  change it to use the new `BranchProbability::normalizeProbabilities`
  function that accepts a range, and also to use llvm::equal() to
  check equality of the two vectors.
- Use `ListSeparator` to print comma-separated lists instead of manual
  code to do that. The same `ListSeparator` variable is reused for
  multiple lists by reassigning it a fresh `ListSeparator()`.
---
 llvm/include/llvm/CodeGen/MachineBasicBlock.h |   4 +-
 llvm/include/llvm/Support/BranchProbability.h |   6 +
 llvm/lib/CodeGen/MIRPrinter.cpp               | 158 ++++++------------
 llvm/lib/CodeGen/MIRPrintingPass.cpp          |   6 +-
 llvm/lib/CodeGen/MachineBasicBlock.cpp        |  42 +++--
 5 files changed, 89 insertions(+), 127 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineBasicBlock.h b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
index d2d90ad868d2d..9c563d761c1d9 100644
--- a/llvm/include/llvm/CodeGen/MachineBasicBlock.h
+++ b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
@@ -1263,6 +1263,9 @@ class MachineBasicBlock
   /// MachineBranchProbabilityInfo class.
   BranchProbability getSuccProbability(const_succ_iterator Succ) const;
 
+  // Helper function for MIRPrinter.
+  bool canPredictBranchProbabilities() const;
+
 private:
   /// Return probability iterator corresponding to the I successor iterator.
   probability_iterator getProbabilityIterator(succ_iterator I);
@@ -1270,7 +1273,6 @@ class MachineBasicBlock
   getProbabilityIterator(const_succ_iterator I) const;
 
   friend class MachineBranchProbabilityInfo;
-  friend class MIPrinter;
 
   // Methods used to maintain doubly linked list of blocks...
   friend struct ilist_callback_traits<MachineBasicBlock>;
diff --git a/llvm/include/llvm/Support/BranchProbability.h b/llvm/include/llvm/Support/BranchProbability.h
index 79d70cf611d41..74731c416ac04 100644
--- a/llvm/include/llvm/Support/BranchProbability.h
+++ b/llvm/include/llvm/Support/BranchProbability.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
 
+#include "llvm/ADT/ADL.h"
 #include "llvm/Support/DataTypes.h"
 #include <algorithm>
 #include <cassert>
@@ -62,6 +63,11 @@ class BranchProbability {
   static void normalizeProbabilities(ProbabilityIter Begin,
                                      ProbabilityIter End);
 
+  template <class ProbabilityContainer>
+  static void normalizeProbabilities(ProbabilityContainer &&R) {
+    normalizeProbabilities(adl_begin(R), adl_end(R));
+  }
+
   uint32_t getNumerator() const { return N; }
   static uint32_t getDenominator() { return D; }
 
diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp
index 2f08fcda1fbd0..906048679553c 100644
--- a/llvm/lib/CodeGen/MIRPrinter.cpp
+++ b/llvm/lib/CodeGen/MIRPrinter.cpp
@@ -17,6 +17,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/MIRYamlMapping.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -93,10 +94,6 @@ struct FrameIndexOperand {
   }
 };
 
-} // end anonymous namespace
-
-namespace llvm {
-
 /// This class prints out the machine functions using the MIR serialization
 /// format.
 class MIRPrinter {
@@ -151,7 +148,6 @@ class MIPrinter {
   /// Synchronization scope names registered with LLVMContext.
   SmallVector<StringRef, 8> SSNs;
 
-  bool canPredictBranchProbabilities(const MachineBasicBlock &MBB) const;
   bool canPredictSuccessors(const MachineBasicBlock &MBB) const;
 
 public:
@@ -167,14 +163,13 @@ class MIPrinter {
   void printStackObjectReference(int FrameIndex);
   void print(const MachineInstr &MI, unsigned OpIdx,
              const TargetRegisterInfo *TRI, const TargetInstrInfo *TII,
-             bool ShouldPrintRegisterTies, LLT TypeToPrint,
-             bool PrintDef = true);
+             bool ShouldPrintRegisterTies, SmallBitVector &PrintedTypes,
+             const MachineRegisterInfo &MRI, bool PrintDef = true);
 };
 
-} // end namespace llvm
+} // end anonymous namespace
 
-namespace llvm {
-namespace yaml {
+namespace llvm::yaml {
 
 /// This struct serializes the LLVM IR module.
 template <> struct BlockScalarTraits<Module> {
@@ -188,8 +183,7 @@ template <> struct BlockScalarTraits<Module> {
   }
 };
 
-} // end namespace yaml
-} // end namespace llvm
+} // end namespace llvm::yaml
 
 static void printRegMIR(Register Reg, yaml::StringValue &Dest,
                         const TargetRegisterInfo *TRI) {
@@ -327,9 +321,8 @@ static void printRegFlags(Register Reg,
                           const MachineFunction &MF,
                           const TargetRegisterInfo *TRI) {
   auto FlagValues = TRI->getVRegFlagsOfReg(Reg, MF);
-  for (auto &Flag : FlagValues) {
+  for (auto &Flag : FlagValues)
     RegisterFlags.push_back(yaml::FlowStringValue(Flag.str()));
-  }
 }
 
 void MIRPrinter::convert(yaml::MachineFunction &YamlMF,
@@ -618,9 +611,8 @@ void MIRPrinter::convertCalledGlobals(yaml::MachineFunction &YMF,
   // Sort by position of call instructions.
   llvm::sort(YMF.CalledGlobals.begin(), YMF.CalledGlobals.end(),
              [](yaml::CalledGlobal A, yaml::CalledGlobal B) {
-               if (A.CallSite.BlockNum == B.CallSite.BlockNum)
-                 return A.CallSite.Offset < B.CallSite.Offset;
-               return A.CallSite.BlockNum < B.CallSite.BlockNum;
+               return std::tie(A.CallSite.BlockNum, A.CallSite.Offset) <
+                      std::tie(B.CallSite.BlockNum, B.CallSite.Offset);
              });
 }
 
@@ -630,11 +622,10 @@ void MIRPrinter::convert(yaml::MachineFunction &MF,
   for (const MachineConstantPoolEntry &Constant : ConstantPool.getConstants()) {
     std::string Str;
     raw_string_ostream StrOS(Str);
-    if (Constant.isMachineConstantPoolEntry()) {
+    if (Constant.isMachineConstantPoolEntry())
       Constant.Val.MachineCPVal->print(StrOS);
-    } else {
+    else
       Constant.Val.ConstVal->printAsOperand(StrOS);
-    }
 
     yaml::MachineConstantPoolValue YamlConstant;
     YamlConstant.ID = ID++;
@@ -693,23 +684,6 @@ void llvm::guessSuccessors(const MachineBasicBlock &MBB,
   IsFallthrough = I == MBB.end() || !I->isBarrier();
 }
 
-bool
-MIPrinter::canPredictBranchProbabilities(const MachineBasicBlock &MBB) const {
-  if (MBB.succ_size() <= 1)
-    return true;
-  if (!MBB.hasSuccessorProbabilities())
-    return true;
-
-  SmallVector<BranchProbability,8> Normalized(MBB.Probs.begin(),
-                                              MBB.Probs.end());
-  BranchProbability::normalizeProbabilities(Normalized.begin(),
-                                            Normalized.end());
-  SmallVector<BranchProbability,8> Equal(Normalized.size());
-  BranchProbability::normalizeProbabilities(Equal.begin(), Equal.end());
-
-  return std::equal(Normalized.begin(), Normalized.end(), Equal.begin());
-}
-
 bool MIPrinter::canPredictSuccessors(const MachineBasicBlock &MBB) const {
   SmallVector<MachineBasicBlock*,8> GuessedSuccs;
   bool GuessedFallthrough;
@@ -738,7 +712,7 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
 
   bool HasLineAttributes = false;
   // Print the successors
-  bool canPredictProbs = canPredictBranchProbabilities(MBB);
+  bool canPredictProbs = MBB.canPredictBranchProbabilities();
   // Even if the list of successors is empty, if we cannot guess it,
   // we need to print it to tell the parser that the list is empty.
   // This is needed, because MI model unreachable as empty blocks
@@ -750,14 +724,12 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
     OS.indent(2) << "successors:";
     if (!MBB.succ_empty())
       OS << " ";
+    ListSeparator LS;
     for (auto I = MBB.succ_begin(), E = MBB.succ_end(); I != E; ++I) {
-      if (I != MBB.succ_begin())
-        OS << ", ";
-      OS << printMBBReference(**I);
+      OS << LS << printMBBReference(**I);
       if (!SimplifyMIR || !canPredictProbs)
-        OS << '('
-           << format("0x%08" PRIx32, MBB.getSuccProbability(I).getNumerator())
-           << ')';
+        OS << format("(0x%08" PRIx32 ")",
+                     MBB.getSuccProbability(I).getNumerator());
     }
     OS << "\n";
     HasLineAttributes = true;
@@ -768,12 +740,9 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
   if (!MBB.livein_empty()) {
     const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
     OS.indent(2) << "liveins: ";
-    bool First = true;
+    ListSeparator LS;
     for (const auto &LI : MBB.liveins_dbg()) {
-      if (!First)
-        OS << ", ";
-      First = false;
-      OS << printReg(LI.PhysReg, &TRI);
+      OS << LS << printReg(LI.PhysReg, &TRI);
       if (!LI.LaneMask.all())
         OS << ":0x" << PrintLaneMask(LI.LaneMask);
     }
@@ -814,14 +783,14 @@ void MIPrinter::print(const MachineInstr &MI) {
 
   SmallBitVector PrintedTypes(8);
   bool ShouldPrintRegisterTies = MI.hasComplexRegisterTies();
+  ListSeparator LS;
   unsigned I = 0, E = MI.getNumOperands();
-  for (; I < E && MI.getOperand(I).isReg() && MI.getOperand(I).isDef() &&
-         !MI.getOperand(I).isImplicit();
-       ++I) {
-    if (I)
-      OS << ", ";
-    print(MI, I, TRI, TII, ShouldPrintRegisterTies,
-          MI.getTypeToPrint(I, PrintedTypes, MRI),
+  for (; I < E; ++I) {
+    const MachineOperand MO = MI.getOperand(I);
+    if (!MO.isReg() || !MO.isDef() || MO.isImplicit())
+      break;
+    OS << LS;
+    print(MI, I, TRI, TII, ShouldPrintRegisterTies, PrintedTypes, MRI,
           /*PrintDef=*/false);
   }
 
@@ -869,74 +838,48 @@ void MIPrinter::print(const MachineInstr &MI) {
     OS << "samesign ";
 
   OS << TII->getName(MI.getOpcode());
-  if (I < E)
-    OS << ' ';
 
-  bool NeedComma = false;
-  for (; I < E; ++I) {
-    if (NeedComma)
-      OS << ", ";
-    print(MI, I, TRI, TII, ShouldPrintRegisterTies,
-          MI.getTypeToPrint(I, PrintedTypes, MRI));
-    NeedComma = true;
+  LS = ListSeparator();
+
+  if (I < E) {
+    OS << ' ';
+    for (; I < E; ++I) {
+      OS << LS;
+      print(MI, I, TRI, TII, ShouldPrintRegisterTies, PrintedTypes, MRI);
+    }
   }
 
   // Print any optional symbols attached to this instruction as-if they were
   // operands.
   if (MCSymbol *PreInstrSymbol = MI.getPreInstrSymbol()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " pre-instr-symbol ";
+    OS << LS << " pre-instr-symbol ";
     MachineOperand::printSymbol(OS, *PreInstrSymbol);
-    NeedComma = true;
   }
   if (MCSymbol *PostInstrSymbol = MI.getPostInstrSymbol()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " post-instr-symbol ";
+    OS << LS << " post-instr-symbol ";
     MachineOperand::printSymbol(OS, *PostInstrSymbol);
-    NeedComma = true;
   }
   if (MDNode *HeapAllocMarker = MI.getHeapAllocMarker()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " heap-alloc-marker ";
+    OS << LS << " heap-alloc-marker ";
     HeapAllocMarker->printAsOperand(OS, MST);
-    NeedComma = true;
   }
   if (MDNode *PCSections = MI.getPCSections()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " pcsections ";
+    OS << LS << " pcsections ";
     PCSections->printAsOperand(OS, MST);
-    NeedComma = true;
   }
   if (MDNode *MMRA = MI.getMMRAMetadata()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " mmra ";
+    OS << LS << " mmra ";
     MMRA->printAsOperand(OS, MST);
-    NeedComma = true;
-  }
-  if (uint32_t CFIType = MI.getCFIType()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " cfi-type " << CFIType;
-    NeedComma = true;
   }
+  if (uint32_t CFIType = MI.getCFIType())
+    OS << LS << " cfi-type " << CFIType;
 
-  if (auto Num = MI.peekDebugInstrNum()) {
-    if (NeedComma)
-      OS << ',';
-    OS << " debug-instr-number " << Num;
-    NeedComma = true;
-  }
+  if (auto Num = MI.peekDebugInstrNum())
+    OS << LS << " debug-instr-number " << Num;
 
   if (PrintLocations) {
     if (const DebugLoc &DL = MI.getDebugLoc()) {
-      if (NeedComma)
-        OS << ',';
-      OS << " debug-location ";
+      OS << LS << " debug-location ";
       DL->printAsOperand(OS, MST);
     }
   }
@@ -945,12 +888,10 @@ void MIPrinter::print(const MachineInstr &MI) {
     OS << " :: ";
     const LLVMContext &Context = MF->getFunction().getContext();
     const MachineFrameInfo &MFI = MF->getFrameInfo();
-    bool NeedComma = false;
+    LS = ListSeparator();
     for (const auto *Op : MI.memoperands()) {
-      if (NeedComma)
-        OS << ", ";
+      OS << LS;
       Op->print(OS, MST, SSNs, Context, &MFI, TII);
-      NeedComma = true;
     }
   }
 }
@@ -971,10 +912,11 @@ static std::string formatOperandComment(std::string Comment) {
 }
 
 void MIPrinter::print(const MachineInstr &MI, unsigned OpIdx,
-                      const TargetRegisterInfo *TRI,
-                      const TargetInstrInfo *TII,
-                      bool ShouldPrintRegisterTies, LLT TypeToPrint,
-                      bool PrintDef) {
+                      const TargetRegisterInfo *TRI, const TargetInstrInfo *TII,
+                      bool ShouldPrintRegisterTies,
+                      SmallBitVector &PrintedTypes,
+                      const MachineRegisterInfo &MRI, bool PrintDef) {
+  LLT TypeToPrint = MI.getTypeToPrint(OpIdx, PrintedTypes, MRI);
   const MachineOperand &Op = MI.getOperand(OpIdx);
   std::string MOComment = TII->createMIROperandComment(MI, Op, OpIdx, TRI);
 
diff --git a/llvm/lib/CodeGen/MIRPrintingPass.cpp b/llvm/lib/CodeGen/MIRPrintingPass.cpp
index fc79410f97b58..28aeb7f116c6c 100644
--- a/llvm/lib/CodeGen/MIRPrintingPass.cpp
+++ b/llvm/lib/CodeGen/MIRPrintingPass.cpp
@@ -81,10 +81,6 @@ char MIRPrintingPass::ID = 0;
 char &llvm::MIRPrintingPassID = MIRPrintingPass::ID;
 INITIALIZE_PASS(MIRPrintingPass, "mir-printer", "MIR Printer", false, false)
 
-namespace llvm {
-
-MachineFunctionPass *createPrintMIRPass(raw_ostream &OS) {
+MachineFunctionPass *llvm::createPrintMIRPass(raw_ostream &OS) {
   return new MIRPrintingPass(OS);
 }
-
-} // end namespace llvm
diff --git a/llvm/lib/CodeGen/MachineBasicBlock.cpp b/llvm/lib/CodeGen/MachineBasicBlock.cpp
index fa6b53455f145..37fe37fd6e423 100644
--- a/llvm/lib/CodeGen/MachineBasicBlock.cpp
+++ b/llvm/lib/CodeGen/MachineBasicBlock.cpp
@@ -1587,20 +1587,36 @@ MachineBasicBlock::getSuccProbability(const_succ_iterator Succ) const {
     return BranchProbability(1, succ_size());
 
   const auto &Prob = *getProbabilityIterator(Succ);
-  if (Prob.isUnknown()) {
-    // For unknown probabilities, collect the sum of all known ones, and evenly
-    // ditribute the complemental of the sum to each unknown probability.
-    unsigned KnownProbNum = 0;
-    auto Sum = BranchProbability::getZero();
-    for (const auto &P : Probs) {
-      if (!P.isUnknown()) {
-        Sum += P;
-        KnownProbNum++;
-      }
-    }
-    return Sum.getCompl() / (Probs.size() - KnownProbNum);
-  } else
+  if (!Prob.isUnknown())
     return Prob;
+  // For unknown probabilities, collect the sum of all known ones, and evenly
+  // distribute the complement of the sum to each unknown probability.
+  unsigned KnownProbNum = 0;
+  auto Sum = BranchProbability::getZero();
+  for (const auto &P : Probs) {
+    if (!P.isUnknown()) {
+      Sum += P;
+      KnownProbNum++;
+    }
+  }
+  return Sum.getCompl() / (Probs.size() - KnownProbNum);
+}
+
+bool MachineBasicBlock::canPredictBranchProbabilities() const {
+  if (succ_size() <= 1)
+    return true;
+  if (!hasSuccessorProbabilities())
+    return true;
+
+  SmallVector<BranchProbability, 8> Normalized(Probs.begin(), Probs.end());
+  BranchProbability::normalizeProbabilities(Normalized);
+
+  // Normalize assuming unknown probabilities. This will assign equal
+  // probabilities to all successors.
+  SmallVector<BranchProbability, 8> Equal(Normalized.size());
+  BranchProbability::normalizeProbabilities(Equal);
+
+  return llvm::equal(Normalized, Equal);
 }
 
 /// Set successor probability of a given iterator.



More information about the llvm-commits mailing list