[llvm] [llvm] fix nullptr dereference in BasicBlock::getIrrLoopHeaderWeight (PR #116192)

Alexander Romanov via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 00:11:04 PST 2024


https://github.com/arrv-sc updated https://github.com/llvm/llvm-project/pull/116192

>From 5cfbef9d796722b154a6aca29c24adeaa7bbaeb6 Mon Sep 17 00:00:00 2001
From: Alexander Romanov <alexander.romanov at syntacore.com>
Date: Mon, 18 Nov 2024 11:10:00 +0300
Subject: [PATCH] [llvm] fix nullptr dereference in
 BasicBlock::getIrrLoopHeaderWeight. Currently, if you call
 `getIrrLoopHeaderWeight` on a `BasicBlock` that has no terminator, the
 variable `TI` is nullptr and dereferencing it is undefined behaviour. This
 commit adds a check for this pointer and returns `std::nullopt` if no
 terminator is found.

---
 llvm/lib/IR/BasicBlock.cpp           | 20 +++++++++++---------
 llvm/unittests/IR/BasicBlockTest.cpp | 10 ++++++++++
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/IR/BasicBlock.cpp b/llvm/lib/IR/BasicBlock.cpp
index 0efc04cb2c8679..83428350b336b6 100644
--- a/llvm/lib/IR/BasicBlock.cpp
+++ b/llvm/lib/IR/BasicBlock.cpp
@@ -684,15 +684,17 @@ const LandingPadInst *BasicBlock::getLandingPadInst() const {
 
 std::optional<uint64_t> BasicBlock::getIrrLoopHeaderWeight() const {
   const Instruction *TI = getTerminator();
-  if (MDNode *MDIrrLoopHeader =
-      TI->getMetadata(LLVMContext::MD_irr_loop)) {
-    MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
-    if (MDName->getString() == "loop_header_weight") {
-      auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
-      return std::optional<uint64_t>(CI->getValue().getZExtValue());
-    }
-  }
-  return std::nullopt;
+  // getTerminator() is null for a malformed (e.g. empty) block.
+  if (!TI)
+    return std::nullopt;
+  MDNode *MDIrrLoopHeader = TI->getMetadata(LLVMContext::MD_irr_loop);
+  if (!MDIrrLoopHeader)
+    return std::nullopt;
+  MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
+  if (MDName->getString() != "loop_header_weight")
+    return std::nullopt;
+  auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
+  return CI->getZExtValue();
 }
 
 BasicBlock::iterator llvm::skipDebugIntrinsics(BasicBlock::iterator It) {
diff --git a/llvm/unittests/IR/BasicBlockTest.cpp b/llvm/unittests/IR/BasicBlockTest.cpp
index 88ac6611742ce9..fbed69c3197011 100644
--- a/llvm/unittests/IR/BasicBlockTest.cpp
+++ b/llvm/unittests/IR/BasicBlockTest.cpp
@@ -583,5 +583,15 @@ TEST(BasicBlockTest, DiscardValueNames2) {
   }
 }
 
+TEST(BasicBlockTest, IrrLoopHeaderNull) {
+  LLVMContext Ctx;
+  Module M("Mod", Ctx);
+  auto *FTy = FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false);
+  auto *F = Function::Create(FTy, Function::ExternalLinkage, "foo", &M);
+  // A freshly created block is empty, so it has no terminator to query.
+  auto *BB = BasicBlock::Create(Ctx, "", F);
+  EXPECT_EQ(BB->getIrrLoopHeaderWeight(), std::nullopt);
+}
+
 } // End anonymous namespace.
 } // End llvm namespace.



More information about the llvm-commits mailing list