[llvm] [CodeGen] Fix `MachineModuleInfo`'s move constructor to be more safe with `MCContext` ownership. (PR #104834)

weiwei chen via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 19:15:43 PDT 2024


https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/104834

>From f6d4993502657531bdacc5ca419862eea5f022e5 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 12:42:34 -0400
Subject: [PATCH 1/4] Make MMI in MachineModuleInfoWrapperPass a unique_ptr.

---
 llvm/include/llvm/CodeGen/MachineModuleInfo.h |  7 ++++---
 llvm/lib/CodeGen/MachineModuleInfo.cpp        | 12 ++++++------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineModuleInfo.h b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
index 310cc4b2abb772..459444be8083c3 100644
--- a/llvm/include/llvm/CodeGen/MachineModuleInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
@@ -169,7 +169,8 @@ class MachineModuleInfo {
 }; // End class MachineModuleInfo
 
 class MachineModuleInfoWrapperPass : public ImmutablePass {
-  MachineModuleInfo MMI;
+  std::unique_ptr<MachineModuleInfo> MMI =
+      std::make_unique<MachineModuleInfo>();
 
 public:
   static char ID; // Pass identification, replacement for typeid
@@ -182,8 +183,8 @@ class MachineModuleInfoWrapperPass : public ImmutablePass {
   bool doInitialization(Module &) override;
   bool doFinalization(Module &) override;
 
-  MachineModuleInfo &getMMI() { return MMI; }
-  const MachineModuleInfo &getMMI() const { return MMI; }
+  MachineModuleInfo &getMMI() { return *MMI; }
+  const MachineModuleInfo &getMMI() const { return *MMI; }
 };
 
 /// An analysis that produces \c MachineModuleInfo for a module.
diff --git a/llvm/lib/CodeGen/MachineModuleInfo.cpp b/llvm/lib/CodeGen/MachineModuleInfo.cpp
index c66495969b4e67..c6249d3c1a237a 100644
--- a/llvm/lib/CodeGen/MachineModuleInfo.cpp
+++ b/llvm/lib/CodeGen/MachineModuleInfo.cpp
@@ -152,13 +152,13 @@ FunctionPass *llvm::createFreeMachineFunctionPass() {
 
 MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
     const LLVMTargetMachine *TM)
-    : ImmutablePass(ID), MMI(TM) {
+    : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM)) {
   initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
 }
 
 MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
     const LLVMTargetMachine *TM, MCContext *ExtContext)
-    : ImmutablePass(ID), MMI(TM, ExtContext) {
+    : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM, ExtContext)) {
   initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
 }
 
@@ -193,10 +193,10 @@ static uint64_t getLocCookie(const SMDiagnostic &SMD, const SourceMgr &SrcMgr,
 }
 
 bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
-  MMI.initialize();
-  MMI.TheModule = &M;
+  MMI->initialize();
+  MMI->TheModule = &M;
   LLVMContext &Ctx = M.getContext();
-  MMI.getContext().setDiagnosticHandler(
+  MMI->getContext().setDiagnosticHandler(
       [&Ctx, &M](const SMDiagnostic &SMD, bool IsInlineAsm,
                  const SourceMgr &SrcMgr,
                  std::vector<const MDNode *> &LocInfos) {
@@ -210,7 +210,7 @@ bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
 }
 
 bool MachineModuleInfoWrapperPass::doFinalization(Module &M) {
-  MMI.finalize();
+  MMI->finalize();
   return false;
 }
 

>From 6186520efebf9223d1e2574e193f4d92181ce23e Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 15:02:34 -0400
Subject: [PATCH 2/4] Make MCContext in MachineModuleInfo unique_ptr instead.

---
 llvm/include/llvm/CodeGen/MachineModuleInfo.h | 27 ++++++------
 llvm/lib/CodeGen/MachineModuleInfo.cpp        | 42 +++++++++----------
 2 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineModuleInfo.h b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
index 459444be8083c3..e8ba8ab53fb7fd 100644
--- a/llvm/include/llvm/CodeGen/MachineModuleInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
@@ -66,7 +66,7 @@ class MachineModuleInfoImpl {
 protected:
   /// Return the entries from a DenseMap in a deterministic sorted orer.
   /// Clears the map.
-  static SymbolListTy getSortedStubs(DenseMap<MCSymbol*, StubValueTy>&);
+  static SymbolListTy getSortedStubs(DenseMap<MCSymbol *, StubValueTy> &);
 
   /// Return the entries from a DenseMap in a deterministic sorted orer.
   /// Clears the map.
@@ -86,7 +86,7 @@ class MachineModuleInfo {
   const LLVMTargetMachine &TM;
 
   /// This is the MCContext used for the entire code generator.
-  MCContext Context;
+  std::unique_ptr<MCContext> Context;
   // This is an external context, that if assigned, will be used instead of the
   // internal context.
   MCContext *ExternalContext = nullptr;
@@ -100,7 +100,7 @@ class MachineModuleInfo {
   MachineModuleInfoImpl *ObjFileMMI;
 
   /// Maps IR Functions to their corresponding MachineFunctions.
-  DenseMap<const Function*, std::unique_ptr<MachineFunction>> MachineFunctions;
+  DenseMap<const Function *, std::unique_ptr<MachineFunction>> MachineFunctions;
   /// Next unique number available for a MachineFunction.
   unsigned NextFnNum = 0;
   const Function *LastRequest = nullptr; ///< Used for shortcut/cache.
@@ -124,10 +124,10 @@ class MachineModuleInfo {
   const LLVMTargetMachine &getTarget() const { return TM; }
 
   const MCContext &getContext() const {
-    return ExternalContext ? *ExternalContext : Context;
+    return ExternalContext ? *ExternalContext : *Context;
   }
   MCContext &getContext() {
-    return ExternalContext ? *ExternalContext : Context;
+    return ExternalContext ? *ExternalContext : *Context;
   }
 
   const Module *getModule() const { return TheModule; }
@@ -153,24 +153,21 @@ class MachineModuleInfo {
 
   /// Keep track of various per-module pieces of information for backends
   /// that would like to do so.
-  template<typename Ty>
-  Ty &getObjFileInfo() {
+  template <typename Ty> Ty &getObjFileInfo() {
     if (ObjFileMMI == nullptr)
       ObjFileMMI = new Ty(*this);
-    return *static_cast<Ty*>(ObjFileMMI);
+    return *static_cast<Ty *>(ObjFileMMI);
   }
 
-  template<typename Ty>
-  const Ty &getObjFileInfo() const {
-    return const_cast<MachineModuleInfo*>(this)->getObjFileInfo<Ty>();
+  template <typename Ty> const Ty &getObjFileInfo() const {
+    return const_cast<MachineModuleInfo *>(this)->getObjFileInfo<Ty>();
   }
 
   /// \}
 }; // End class MachineModuleInfo
 
 class MachineModuleInfoWrapperPass : public ImmutablePass {
-  std::unique_ptr<MachineModuleInfo> MMI =
-      std::make_unique<MachineModuleInfo>();
+  MachineModuleInfo MMI;
 
 public:
   static char ID; // Pass identification, replacement for typeid
@@ -183,8 +180,8 @@ class MachineModuleInfoWrapperPass : public ImmutablePass {
   bool doInitialization(Module &) override;
   bool doFinalization(Module &) override;
 
-  MachineModuleInfo &getMMI() { return *MMI; }
-  const MachineModuleInfo &getMMI() const { return *MMI; }
+  MachineModuleInfo &getMMI() { return MMI; }
+  const MachineModuleInfo &getMMI() const { return MMI; }
 };
 
 /// An analysis that produces \c MachineModuleInfo for a module.
diff --git a/llvm/lib/CodeGen/MachineModuleInfo.cpp b/llvm/lib/CodeGen/MachineModuleInfo.cpp
index c6249d3c1a237a..3d87b6261f9324 100644
--- a/llvm/lib/CodeGen/MachineModuleInfo.cpp
+++ b/llvm/lib/CodeGen/MachineModuleInfo.cpp
@@ -30,7 +30,8 @@ void MachineModuleInfo::initialize() {
 }
 
 void MachineModuleInfo::finalize() {
-  Context.reset();
+  if (Context)
+    Context->reset();
   // We don't clear the ExternalContext.
 
   delete ObjFileMMI;
@@ -38,31 +39,33 @@ void MachineModuleInfo::finalize() {
 }

 MachineModuleInfo::MachineModuleInfo(MachineModuleInfo &&MMI)
-    : TM(std::move(MMI.TM)),
-      Context(TM.getTargetTriple(), TM.getMCAsmInfo(), TM.getMCRegisterInfo(),
-              TM.getMCSubtargetInfo(), nullptr, &TM.Options.MCOptions, false),
+    : TM(std::move(MMI.TM)), Context(std::move(MMI.Context)),
       MachineFunctions(std::move(MMI.MachineFunctions)) {
-  Context.setObjectFileInfo(TM.getObjFileLowering());
   ObjFileMMI = MMI.ObjFileMMI;
+  // finalize() deletes ObjFileMMI; clear the moved-from copy so the pointer is
+  // not freed a second time when the source object is finalized.
+  MMI.ObjFileMMI = nullptr;
   ExternalContext = MMI.ExternalContext;
   TheModule = MMI.TheModule;
 }

 MachineModuleInfo::MachineModuleInfo(const LLVMTargetMachine *TM)
-    : TM(*TM), Context(TM->getTargetTriple(), TM->getMCAsmInfo(),
-                       TM->getMCRegisterInfo(), TM->getMCSubtargetInfo(),
-                       nullptr, &TM->Options.MCOptions, false) {
-  Context.setObjectFileInfo(TM->getObjFileLowering());
+    : TM(*TM),
+      Context(std::make_unique<MCContext>(
+          TM->getTargetTriple(), TM->getMCAsmInfo(), TM->getMCRegisterInfo(),
+          TM->getMCSubtargetInfo(), nullptr, &TM->Options.MCOptions, false)) {
+  Context->setObjectFileInfo(TM->getObjFileLowering());
   initialize();
 }

 MachineModuleInfo::MachineModuleInfo(const LLVMTargetMachine *TM,
                                      MCContext *ExtContext)
-    : TM(*TM), Context(TM->getTargetTriple(), TM->getMCAsmInfo(),
-                       TM->getMCRegisterInfo(), TM->getMCSubtargetInfo(),
-                       nullptr, &TM->Options.MCOptions, false),
+    : TM(*TM),
+      Context(std::make_unique<MCContext>(
+          TM->getTargetTriple(), TM->getMCAsmInfo(), TM->getMCRegisterInfo(),
+          TM->getMCSubtargetInfo(), nullptr, &TM->Options.MCOptions, false)),
       ExternalContext(ExtContext) {
-  Context.setObjectFileInfo(TM->getObjFileLowering());
+  Context->setObjectFileInfo(TM->getObjFileLowering());
   initialize();
 }

@@ -137,9 +137,7 @@ class FreeMachineFunction : public FunctionPass {
     return true;
   }
 
-  StringRef getPassName() const override {
-    return "Free MachineFunction";
-  }
+  StringRef getPassName() const override { return "Free MachineFunction"; }
 };
 
 } // end anonymous namespace
@@ -152,13 +150,13 @@ FunctionPass *llvm::createFreeMachineFunctionPass() {
 
 MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
     const LLVMTargetMachine *TM)
-    : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM)) {
+    : ImmutablePass(ID), MMI(TM) {
   initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
 }
 
 MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
     const LLVMTargetMachine *TM, MCContext *ExtContext)
-    : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM, ExtContext)) {
+    : ImmutablePass(ID), MMI(TM) {
   initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
 }
 
@@ -193,10 +191,10 @@ static uint64_t getLocCookie(const SMDiagnostic &SMD, const SourceMgr &SrcMgr,
 }
 
 bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
-  MMI->initialize();
-  MMI->TheModule = &M;
+  MMI.initialize();
+  MMI.TheModule = &M;
   LLVMContext &Ctx = M.getContext();
-  MMI->getContext().setDiagnosticHandler(
+  MMI.getContext().setDiagnosticHandler(
       [&Ctx, &M](const SMDiagnostic &SMD, bool IsInlineAsm,
                  const SourceMgr &SrcMgr,
                  std::vector<const MDNode *> &LocInfos) {
@@ -210,7 +208,7 @@ bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
 }
 
 bool MachineModuleInfoWrapperPass::doFinalization(Module &M) {
-  MMI->finalize();
+  MMI.finalize();
   return false;
 }
 

>From 9b6e6a053297920bb9b3b1e24e160d1c411aa7da Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 15:09:28 -0400
Subject: [PATCH 3/4] Redo format.

---
 llvm/include/llvm/CodeGen/MachineModuleInfo.h | 14 ++++++++------
 llvm/lib/CodeGen/MachineModuleInfo.cpp        |  6 ++++--
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineModuleInfo.h b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
index e8ba8ab53fb7fd..2f0ece30f90786 100644
--- a/llvm/include/llvm/CodeGen/MachineModuleInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
@@ -66,7 +66,7 @@ class MachineModuleInfoImpl {
 protected:
   /// Return the entries from a DenseMap in a deterministic sorted orer.
   /// Clears the map.
-  static SymbolListTy getSortedStubs(DenseMap<MCSymbol *, StubValueTy> &);
+  static SymbolListTy getSortedStubs(DenseMap<MCSymbol*, StubValueTy>&);
 
   /// Return the entries from a DenseMap in a deterministic sorted orer.
   /// Clears the map.
@@ -100,7 +100,7 @@ class MachineModuleInfo {
   MachineModuleInfoImpl *ObjFileMMI;
 
   /// Maps IR Functions to their corresponding MachineFunctions.
-  DenseMap<const Function *, std::unique_ptr<MachineFunction>> MachineFunctions;
+  DenseMap<const Function*, std::unique_ptr<MachineFunction>> MachineFunctions;
   /// Next unique number available for a MachineFunction.
   unsigned NextFnNum = 0;
   const Function *LastRequest = nullptr; ///< Used for shortcut/cache.
@@ -153,14 +153,16 @@ class MachineModuleInfo {
 
   /// Keep track of various per-module pieces of information for backends
   /// that would like to do so.
-  template <typename Ty> Ty &getObjFileInfo() {
+  template<typename Ty>
+  Ty &getObjFileInfo() {
     if (ObjFileMMI == nullptr)
       ObjFileMMI = new Ty(*this);
-    return *static_cast<Ty *>(ObjFileMMI);
+    return *static_cast<Ty*>(ObjFileMMI);
   }
 
-  template <typename Ty> const Ty &getObjFileInfo() const {
-    return const_cast<MachineModuleInfo *>(this)->getObjFileInfo<Ty>();
+  template<typename Ty>
+  const Ty &getObjFileInfo() const {
+    return const_cast<MachineModuleInfo*>(this)->getObjFileInfo<Ty>();
   }
 
   /// \}
diff --git a/llvm/lib/CodeGen/MachineModuleInfo.cpp b/llvm/lib/CodeGen/MachineModuleInfo.cpp
index 3d87b6261f9324..4c757b62a112a1 100644
--- a/llvm/lib/CodeGen/MachineModuleInfo.cpp
+++ b/llvm/lib/CodeGen/MachineModuleInfo.cpp
@@ -137,7 +137,9 @@ class FreeMachineFunction : public FunctionPass {
     return true;
   }
 
-  StringRef getPassName() const override { return "Free MachineFunction"; }
+  StringRef getPassName() const override {
+    return "Free MachineFunction";
+  }
 };
 
 } // end anonymous namespace
@@ -156,7 +158,7 @@ MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
 
 MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
     const LLVMTargetMachine *TM, MCContext *ExtContext)
-    : ImmutablePass(ID), MMI(TM) {
+    : ImmutablePass(ID), MMI(TM, ExtContext) {
   initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
 }
 

>From 160845f17b0ba82ca0f3be55f35e3a911a5ac3cf Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 22:14:40 -0400
Subject: [PATCH 4/4] Add a unittest for the move constructor change.

---
 llvm/unittests/CodeGen/CMakeLists.txt         |  1 +
 .../CodeGen/MachineModuleInfoTest.cpp         | 86 +++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 llvm/unittests/CodeGen/MachineModuleInfoTest.cpp

diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index 963cdcc0275e16..ac3806a95cc488 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -35,6 +35,7 @@ add_llvm_unittest(CodeGenTests
   MachineDomTreeUpdaterTest.cpp
   MachineInstrBundleIteratorTest.cpp
   MachineInstrTest.cpp
+  MachineModuleInfoTest.cpp
   MachineOperandTest.cpp
   RegAllocScoreTest.cpp
   PassManagerTest.cpp
diff --git a/llvm/unittests/CodeGen/MachineModuleInfoTest.cpp b/llvm/unittests/CodeGen/MachineModuleInfoTest.cpp
new file mode 100644
index 00000000000000..51fe57c59b48b8
--- /dev/null
+++ b/llvm/unittests/CodeGen/MachineModuleInfoTest.cpp
@@ -0,0 +1,86 @@
+//===- MachineModuleInfoTest.cpp - MachineModuleInfo unit tests -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Test MachineModuleInfo.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/MC/MCContext.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
+  SMDiagnostic Err;
+  return parseAssemblyString(IR, Err, Context);
+}
+
+class MachineModuleInfoTest : public ::testing::Test {
+protected:
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<TargetMachine> TM;
+
+public:
+  MachineModuleInfoTest()
+      : M(parseIR(Context, "define void @f() {\n"
+                           "entry:\n"
+                           "  call void @g()\n"
+                           "  ret void\n"
+                           "}\n"
+                           "define void @g() {\n"
+                           "  ret void\n"
+                           "}\n")) {
+    // MachineModuleInfo needs a TargetMachine instance. If the default target
+    // triple's target is not registered in this build, TM is left null and
+    // the test is skipped.
+    llvm::InitializeAllTargets();
+
+    std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
+    std::string Error;
+    const Target *TheTarget = TargetRegistry::lookupTarget(TripleName, Error);
+    if (!TheTarget)
+      return;
+
+    TargetOptions Options;
+    TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
+                                            std::nullopt));
+  }
+};
+
+TEST_F(MachineModuleInfoTest, MachineModuleInfoMoveConstructorMovesMCContext) {
+  if (!TM)
+    GTEST_SKIP();
+  ASSERT_TRUE(M);
+
+  LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
+  M->setDataLayout(TM->createDataLayout());
+
+  MachineModuleInfo MMI(LLVMTM);
+  MCContext *OriginalCtx = &MMI.getContext();
+
+  // The moved-to object must take over the original MCContext rather than
+  // constructing a fresh one.
+  MachineModuleInfo MovedMMI(std::move(MMI));
+  MCContext *MovedCtx = &MovedMMI.getContext();
+
+  EXPECT_EQ(OriginalCtx, MovedCtx);
+}
+
+} // end namespace



More information about the llvm-commits mailing list