[llvm] [llvm] fix nullptr dereference in BasicBlock::getIrrLoopHeaderWeight (PR #116192)
Alexander Romanov via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 5 04:27:58 PST 2024
https://github.com/arrv-sc updated https://github.com/llvm/llvm-project/pull/116192
>From 2b6ee1110258361b20668b5ca20a5ef1fb4aefa3 Mon Sep 17 00:00:00 2001
From: Alexander Romanov <alexander.romanov at syntacore.com>
Date: Mon, 18 Nov 2024 11:10:00 +0300
Subject: [PATCH] [llvm] fix nullptr dereference in
BasicBlock::getIrrLoopHeaderWeight

Currently, calling `getIrrLoopHeaderWeight` on a `BasicBlock` that has no
terminator leaves the variable `TI` equal to nullptr, and dereferencing it
is undefined behaviour. This commit adds a check for this pointer and
returns `std::nullopt` if no terminator is found.
---
llvm/lib/IR/BasicBlock.cpp | 20 +++++++++++---------
llvm/unittests/IR/BasicBlockTest.cpp | 10 ++++++++++
2 files changed, 21 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/IR/BasicBlock.cpp b/llvm/lib/IR/BasicBlock.cpp
index 0efc04cb2c8679..83428350b336b6 100644
--- a/llvm/lib/IR/BasicBlock.cpp
+++ b/llvm/lib/IR/BasicBlock.cpp
@@ -684,15 +684,17 @@ const LandingPadInst *BasicBlock::getLandingPadInst() const {
std::optional<uint64_t> BasicBlock::getIrrLoopHeaderWeight() const {
const Instruction *TI = getTerminator();
- if (MDNode *MDIrrLoopHeader =
- TI->getMetadata(LLVMContext::MD_irr_loop)) {
- MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
- if (MDName->getString() == "loop_header_weight") {
- auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
- return std::optional<uint64_t>(CI->getValue().getZExtValue());
- }
- }
- return std::nullopt;
+ if (!TI) // No terminator (e.g. an empty block): nothing to query.
+ return std::nullopt;
+ MDNode *MDIrrLoopHeader = TI->getMetadata(LLVMContext::MD_irr_loop);
+ if (!MDIrrLoopHeader)
+ return std::nullopt;
+ // cast<> already asserts on a null or non-MDString operand.
+ MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
+ if (MDName->getString() != "loop_header_weight")
+ return std::nullopt;
+ auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
+ return CI->getZExtValue();
}
BasicBlock::iterator llvm::skipDebugIntrinsics(BasicBlock::iterator It) {
diff --git a/llvm/unittests/IR/BasicBlockTest.cpp b/llvm/unittests/IR/BasicBlockTest.cpp
index 36e849471d1ed8..934205a1103251 100644
--- a/llvm/unittests/IR/BasicBlockTest.cpp
+++ b/llvm/unittests/IR/BasicBlockTest.cpp
@@ -583,5 +583,15 @@ TEST(BasicBlockTest, DiscardValueNames2) {
}
}
+TEST(BasicBlockTest, IrrLoopHeaderNull) {
+ LLVMContext Ctx;
+ Module M("Mod", Ctx);
+ auto *FTy = FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false);
+ auto *F = Function::Create(FTy, Function::ExternalLinkage, "foo", &M);
+ // A freshly created block has no instructions, hence no terminator.
+ auto *BB = BasicBlock::Create(Ctx, "", F);
+ EXPECT_EQ(BB->getIrrLoopHeaderWeight(), std::nullopt);
+}
+
} // End anonymous namespace.
} // End llvm namespace.
More information about the llvm-commits
mailing list