[llvm] [CodeGen] Fix `MachineModuleInfo`'s move constructor to be more safe with `MCContext` ownership. (PR #104834)
weiwei chen via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 19 19:15:43 PDT 2024
https://github.com/weiweichen updated https://github.com/llvm/llvm-project/pull/104834
>From f6d4993502657531bdacc5ca419862eea5f022e5 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 12:42:34 -0400
Subject: [PATCH 1/4] Make MMI in MachineModuleInfoWrapperPass a unique_ptr.
---
llvm/include/llvm/CodeGen/MachineModuleInfo.h | 7 ++++---
llvm/lib/CodeGen/MachineModuleInfo.cpp | 12 ++++++------
2 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineModuleInfo.h b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
index 310cc4b2abb772..459444be8083c3 100644
--- a/llvm/include/llvm/CodeGen/MachineModuleInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
@@ -169,7 +169,8 @@ class MachineModuleInfo {
}; // End class MachineModuleInfo
class MachineModuleInfoWrapperPass : public ImmutablePass {
- MachineModuleInfo MMI;
+ std::unique_ptr<MachineModuleInfo> MMI =
+ std::make_unique<MachineModuleInfo>();
public:
static char ID; // Pass identification, replacement for typeid
@@ -182,8 +183,8 @@ class MachineModuleInfoWrapperPass : public ImmutablePass {
bool doInitialization(Module &) override;
bool doFinalization(Module &) override;
- MachineModuleInfo &getMMI() { return MMI; }
- const MachineModuleInfo &getMMI() const { return MMI; }
+ MachineModuleInfo &getMMI() { return *MMI; }
+ const MachineModuleInfo &getMMI() const { return *MMI; }
};
/// An analysis that produces \c MachineModuleInfo for a module.
diff --git a/llvm/lib/CodeGen/MachineModuleInfo.cpp b/llvm/lib/CodeGen/MachineModuleInfo.cpp
index c66495969b4e67..c6249d3c1a237a 100644
--- a/llvm/lib/CodeGen/MachineModuleInfo.cpp
+++ b/llvm/lib/CodeGen/MachineModuleInfo.cpp
@@ -152,13 +152,13 @@ FunctionPass *llvm::createFreeMachineFunctionPass() {
MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
const LLVMTargetMachine *TM)
- : ImmutablePass(ID), MMI(TM) {
+ : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM)) {
initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}
MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
const LLVMTargetMachine *TM, MCContext *ExtContext)
- : ImmutablePass(ID), MMI(TM, ExtContext) {
+ : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM, ExtContext)) {
initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}
@@ -193,10 +193,10 @@ static uint64_t getLocCookie(const SMDiagnostic &SMD, const SourceMgr &SrcMgr,
}
bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
- MMI.initialize();
- MMI.TheModule = &M;
+ MMI->initialize();
+ MMI->TheModule = &M;
LLVMContext &Ctx = M.getContext();
- MMI.getContext().setDiagnosticHandler(
+ MMI->getContext().setDiagnosticHandler(
[&Ctx, &M](const SMDiagnostic &SMD, bool IsInlineAsm,
const SourceMgr &SrcMgr,
std::vector<const MDNode *> &LocInfos) {
@@ -210,7 +210,7 @@ bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
}
bool MachineModuleInfoWrapperPass::doFinalization(Module &M) {
- MMI.finalize();
+ MMI->finalize();
return false;
}
>From 6186520efebf9223d1e2574e193f4d92181ce23e Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 15:02:34 -0400
Subject: [PATCH 2/4] Make MCContext in MachineModuleInfo unique_ptr instead.
---
llvm/include/llvm/CodeGen/MachineModuleInfo.h | 27 ++++++------
llvm/lib/CodeGen/MachineModuleInfo.cpp | 42 +++++++++----------
2 files changed, 32 insertions(+), 37 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineModuleInfo.h b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
index 459444be8083c3..e8ba8ab53fb7fd 100644
--- a/llvm/include/llvm/CodeGen/MachineModuleInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
@@ -66,7 +66,7 @@ class MachineModuleInfoImpl {
protected:
/// Return the entries from a DenseMap in a deterministic sorted orer.
/// Clears the map.
- static SymbolListTy getSortedStubs(DenseMap<MCSymbol*, StubValueTy>&);
+ static SymbolListTy getSortedStubs(DenseMap<MCSymbol *, StubValueTy> &);
/// Return the entries from a DenseMap in a deterministic sorted orer.
/// Clears the map.
@@ -86,7 +86,7 @@ class MachineModuleInfo {
const LLVMTargetMachine &TM;
/// This is the MCContext used for the entire code generator.
- MCContext Context;
+ std::unique_ptr<MCContext> Context;
// This is an external context, that if assigned, will be used instead of the
// internal context.
MCContext *ExternalContext = nullptr;
@@ -100,7 +100,7 @@ class MachineModuleInfo {
MachineModuleInfoImpl *ObjFileMMI;
/// Maps IR Functions to their corresponding MachineFunctions.
- DenseMap<const Function*, std::unique_ptr<MachineFunction>> MachineFunctions;
+ DenseMap<const Function *, std::unique_ptr<MachineFunction>> MachineFunctions;
/// Next unique number available for a MachineFunction.
unsigned NextFnNum = 0;
const Function *LastRequest = nullptr; ///< Used for shortcut/cache.
@@ -124,10 +124,10 @@ class MachineModuleInfo {
const LLVMTargetMachine &getTarget() const { return TM; }
const MCContext &getContext() const {
- return ExternalContext ? *ExternalContext : Context;
+ return ExternalContext ? *ExternalContext : *Context;
}
MCContext &getContext() {
- return ExternalContext ? *ExternalContext : Context;
+ return ExternalContext ? *ExternalContext : *Context;
}
const Module *getModule() const { return TheModule; }
@@ -153,24 +153,21 @@ class MachineModuleInfo {
/// Keep track of various per-module pieces of information for backends
/// that would like to do so.
- template<typename Ty>
- Ty &getObjFileInfo() {
+ template <typename Ty> Ty &getObjFileInfo() {
if (ObjFileMMI == nullptr)
ObjFileMMI = new Ty(*this);
- return *static_cast<Ty*>(ObjFileMMI);
+ return *static_cast<Ty *>(ObjFileMMI);
}
- template<typename Ty>
- const Ty &getObjFileInfo() const {
- return const_cast<MachineModuleInfo*>(this)->getObjFileInfo<Ty>();
+ template <typename Ty> const Ty &getObjFileInfo() const {
+ return const_cast<MachineModuleInfo *>(this)->getObjFileInfo<Ty>();
}
/// \}
}; // End class MachineModuleInfo
class MachineModuleInfoWrapperPass : public ImmutablePass {
- std::unique_ptr<MachineModuleInfo> MMI =
- std::make_unique<MachineModuleInfo>();
+ MachineModuleInfo MMI;
public:
static char ID; // Pass identification, replacement for typeid
@@ -183,8 +180,8 @@ class MachineModuleInfoWrapperPass : public ImmutablePass {
bool doInitialization(Module &) override;
bool doFinalization(Module &) override;
- MachineModuleInfo &getMMI() { return *MMI; }
- const MachineModuleInfo &getMMI() const { return *MMI; }
+ MachineModuleInfo &getMMI() { return MMI; }
+ const MachineModuleInfo &getMMI() const { return MMI; }
};
/// An analysis that produces \c MachineModuleInfo for a module.
diff --git a/llvm/lib/CodeGen/MachineModuleInfo.cpp b/llvm/lib/CodeGen/MachineModuleInfo.cpp
index c6249d3c1a237a..3d87b6261f9324 100644
--- a/llvm/lib/CodeGen/MachineModuleInfo.cpp
+++ b/llvm/lib/CodeGen/MachineModuleInfo.cpp
@@ -30,7 +30,8 @@ void MachineModuleInfo::initialize() {
}
void MachineModuleInfo::finalize() {
- Context.reset();
+ if (Context)
+ Context->reset();
// We don't clear the ExternalContext.
delete ObjFileMMI;
@@ -38,31 +39,30 @@ void MachineModuleInfo::finalize() {
}
MachineModuleInfo::MachineModuleInfo(MachineModuleInfo &&MMI)
- : TM(std::move(MMI.TM)),
- Context(TM.getTargetTriple(), TM.getMCAsmInfo(), TM.getMCRegisterInfo(),
- TM.getMCSubtargetInfo(), nullptr, &TM.Options.MCOptions, false),
+ : TM(std::move(MMI.TM)), Context(std::move(MMI.Context)),
MachineFunctions(std::move(MMI.MachineFunctions)) {
- Context.setObjectFileInfo(TM.getObjFileLowering());
- ObjFileMMI = MMI.ObjFileMMI;
+ ObjFileMMI = std::exchange(MMI.ObjFileMMI, nullptr); // Avoid double delete.
ExternalContext = MMI.ExternalContext;
TheModule = MMI.TheModule;
}
MachineModuleInfo::MachineModuleInfo(const LLVMTargetMachine *TM)
- : TM(*TM), Context(TM->getTargetTriple(), TM->getMCAsmInfo(),
- TM->getMCRegisterInfo(), TM->getMCSubtargetInfo(),
- nullptr, &TM->Options.MCOptions, false) {
- Context.setObjectFileInfo(TM->getObjFileLowering());
+ : TM(*TM),
+ Context(std::make_unique<MCContext>(
+ TM->getTargetTriple(), TM->getMCAsmInfo(), TM->getMCRegisterInfo(),
+ TM->getMCSubtargetInfo(), nullptr, &TM->Options.MCOptions, false)) {
+ Context->setObjectFileInfo(TM->getObjFileLowering());
initialize();
}
MachineModuleInfo::MachineModuleInfo(const LLVMTargetMachine *TM,
MCContext *ExtContext)
- : TM(*TM), Context(TM->getTargetTriple(), TM->getMCAsmInfo(),
- TM->getMCRegisterInfo(), TM->getMCSubtargetInfo(),
- nullptr, &TM->Options.MCOptions, false),
+ : TM(*TM),
+ Context(std::make_unique<MCContext>(
+ TM->getTargetTriple(), TM->getMCAsmInfo(), TM->getMCRegisterInfo(),
+ TM->getMCSubtargetInfo(), nullptr, &TM->Options.MCOptions, false)),
ExternalContext(ExtContext) {
- Context.setObjectFileInfo(TM->getObjFileLowering());
+ Context->setObjectFileInfo(TM->getObjFileLowering());
initialize();
}
@@ -137,9 +137,7 @@ class FreeMachineFunction : public FunctionPass {
return true;
}
- StringRef getPassName() const override {
- return "Free MachineFunction";
- }
+ StringRef getPassName() const override { return "Free MachineFunction"; }
};
} // end anonymous namespace
@@ -152,13 +150,13 @@ FunctionPass *llvm::createFreeMachineFunctionPass() {
MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
const LLVMTargetMachine *TM)
- : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM)) {
+ : ImmutablePass(ID), MMI(TM) {
initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}
MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
const LLVMTargetMachine *TM, MCContext *ExtContext)
- : ImmutablePass(ID), MMI(std::make_unique<MachineModuleInfo>(TM, ExtContext)) {
+ : ImmutablePass(ID), MMI(TM) {
initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}
@@ -193,10 +191,10 @@ static uint64_t getLocCookie(const SMDiagnostic &SMD, const SourceMgr &SrcMgr,
}
bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
- MMI->initialize();
- MMI->TheModule = &M;
+ MMI.initialize();
+ MMI.TheModule = &M;
LLVMContext &Ctx = M.getContext();
- MMI->getContext().setDiagnosticHandler(
+ MMI.getContext().setDiagnosticHandler(
[&Ctx, &M](const SMDiagnostic &SMD, bool IsInlineAsm,
const SourceMgr &SrcMgr,
std::vector<const MDNode *> &LocInfos) {
@@ -210,7 +208,7 @@ bool MachineModuleInfoWrapperPass::doInitialization(Module &M) {
}
bool MachineModuleInfoWrapperPass::doFinalization(Module &M) {
- MMI->finalize();
+ MMI.finalize();
return false;
}
>From 9b6e6a053297920bb9b3b1e24e160d1c411aa7da Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 15:09:28 -0400
Subject: [PATCH 3/4] Redo format.
---
llvm/include/llvm/CodeGen/MachineModuleInfo.h | 14 ++++++++------
llvm/lib/CodeGen/MachineModuleInfo.cpp | 6 ++++--
2 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineModuleInfo.h b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
index e8ba8ab53fb7fd..2f0ece30f90786 100644
--- a/llvm/include/llvm/CodeGen/MachineModuleInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineModuleInfo.h
@@ -66,7 +66,7 @@ class MachineModuleInfoImpl {
protected:
/// Return the entries from a DenseMap in a deterministic sorted orer.
/// Clears the map.
- static SymbolListTy getSortedStubs(DenseMap<MCSymbol *, StubValueTy> &);
+ static SymbolListTy getSortedStubs(DenseMap<MCSymbol*, StubValueTy>&);
/// Return the entries from a DenseMap in a deterministic sorted orer.
/// Clears the map.
@@ -100,7 +100,7 @@ class MachineModuleInfo {
MachineModuleInfoImpl *ObjFileMMI;
/// Maps IR Functions to their corresponding MachineFunctions.
- DenseMap<const Function *, std::unique_ptr<MachineFunction>> MachineFunctions;
+ DenseMap<const Function*, std::unique_ptr<MachineFunction>> MachineFunctions;
/// Next unique number available for a MachineFunction.
unsigned NextFnNum = 0;
const Function *LastRequest = nullptr; ///< Used for shortcut/cache.
@@ -153,14 +153,16 @@ class MachineModuleInfo {
/// Keep track of various per-module pieces of information for backends
/// that would like to do so.
- template <typename Ty> Ty &getObjFileInfo() {
+ template<typename Ty>
+ Ty &getObjFileInfo() {
if (ObjFileMMI == nullptr)
ObjFileMMI = new Ty(*this);
- return *static_cast<Ty *>(ObjFileMMI);
+ return *static_cast<Ty*>(ObjFileMMI);
}
- template <typename Ty> const Ty &getObjFileInfo() const {
- return const_cast<MachineModuleInfo *>(this)->getObjFileInfo<Ty>();
+ template<typename Ty>
+ const Ty &getObjFileInfo() const {
+ return const_cast<MachineModuleInfo*>(this)->getObjFileInfo<Ty>();
}
/// \}
diff --git a/llvm/lib/CodeGen/MachineModuleInfo.cpp b/llvm/lib/CodeGen/MachineModuleInfo.cpp
index 3d87b6261f9324..4c757b62a112a1 100644
--- a/llvm/lib/CodeGen/MachineModuleInfo.cpp
+++ b/llvm/lib/CodeGen/MachineModuleInfo.cpp
@@ -137,7 +137,9 @@ class FreeMachineFunction : public FunctionPass {
return true;
}
- StringRef getPassName() const override { return "Free MachineFunction"; }
+ StringRef getPassName() const override {
+ return "Free MachineFunction";
+ }
};
} // end anonymous namespace
@@ -156,7 +158,7 @@ MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
MachineModuleInfoWrapperPass::MachineModuleInfoWrapperPass(
const LLVMTargetMachine *TM, MCContext *ExtContext)
- : ImmutablePass(ID), MMI(TM) {
+ : ImmutablePass(ID), MMI(TM, ExtContext) {
initializeMachineModuleInfoWrapperPassPass(*PassRegistry::getPassRegistry());
}
>From 160845f17b0ba82ca0f3be55f35e3a911a5ac3cf Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Mon, 19 Aug 2024 22:14:40 -0400
Subject: [PATCH 4/4] Add a unittest for the move constructor change.
---
llvm/unittests/CodeGen/CMakeLists.txt | 1 +
.../CodeGen/MachineModuleInfoTest.cpp | 86 +++++++++++++++++++
2 files changed, 87 insertions(+)
create mode 100644 llvm/unittests/CodeGen/MachineModuleInfoTest.cpp
diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index 963cdcc0275e16..ac3806a95cc488 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -35,6 +35,7 @@ add_llvm_unittest(CodeGenTests
MachineDomTreeUpdaterTest.cpp
MachineInstrBundleIteratorTest.cpp
MachineInstrTest.cpp
+ MachineModuleInfoTest.cpp
MachineOperandTest.cpp
RegAllocScoreTest.cpp
PassManagerTest.cpp
diff --git a/llvm/unittests/CodeGen/MachineModuleInfoTest.cpp b/llvm/unittests/CodeGen/MachineModuleInfoTest.cpp
new file mode 100644
index 00000000000000..51fe57c59b48b8
--- /dev/null
+++ b/llvm/unittests/CodeGen/MachineModuleInfoTest.cpp
@@ -0,0 +1,86 @@
+//===- llvm/unittests/CodeGen/MachineModuleInfoTest.cpp - MMI tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Test MachineModuleInfo.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+/// Fixture providing a TargetMachine for the host's default triple, so that
+/// MachineModuleInfo can be constructed against a real target. TM is left
+/// null when that target is not built into LLVM; tests must then skip.
+class MachineModuleInfoTest : public ::testing::Test {
+protected:
+  std::unique_ptr<TargetMachine> TM;
+
+public:
+  MachineModuleInfoTest() {
+    // MachineModuleInfo needs a TargetMachine instance. Creating one also
+    // builds the target's MC layer (MCAsmInfo, MCRegisterInfo, ...), whose
+    // factories are registered by InitializeAllTargetMCs() rather than by
+    // InitializeAllTargets(); register them here so this test does not
+    // depend on another test in the binary having done it first.
+    llvm::InitializeAllTargets();
+    llvm::InitializeAllTargetMCs();
+
+    std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
+    std::string Error;
+    const Target *TheTarget = TargetRegistry::lookupTarget(TripleName, Error);
+    if (!TheTarget)
+      return;
+
+    TargetOptions Options;
+    TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
+                                            std::nullopt));
+  }
+};
+
+// MachineModuleInfo owns its MCContext through a std::unique_ptr, so moving
+// an MMI must hand over the very same MCContext object instead of building a
+// fresh one (which is what the move constructor used to do). The moved
+// context must also keep the object-file info set at construction, since the
+// move constructor no longer calls setObjectFileInfo() itself.
+TEST_F(MachineModuleInfoTest, MachineModuleInfoMoveConstructorMovesMCContext) {
+  if (!TM)
+    GTEST_SKIP();
+
+  LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
+
+  MachineModuleInfo MMI(LLVMTM);
+  MCContext *OriginalCtx = &MMI.getContext();
+
+  MachineModuleInfo MovedMMI(std::move(MMI));
+  MCContext *MovedCtx = &MovedMMI.getContext();
+
+  EXPECT_EQ(OriginalCtx, MovedCtx);
+  EXPECT_NE(MovedCtx->getObjectFileInfo(), nullptr);
+
+  // Both MMIs are destroyed at scope exit. The moved-from MMI no longer owns
+  // an MCContext, so this also checks that tearing down a moved-from
+  // MachineModuleInfo is safe.
+}
+
+} // end namespace
More information about the llvm-commits
mailing list