[llvm] [llvm] fix nullptr dereference in BasicBlock::getIrrLoopHeaderWeight (PR #116192)

Alexander Romanov via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 14 01:43:34 PST 2024


https://github.com/arrv-sc created https://github.com/llvm/llvm-project/pull/116192

Currently, calling `getIrrLoopHeaderWeight` on a `BasicBlock` that has no terminator dereferences a null `TI` pointer, which is undefined behaviour. This commit adds a check for this pointer and returns `std::nullopt` if there is no terminator.

>From 1ce9bd2fc13613eaa37efe927fcabb63550a687e Mon Sep 17 00:00:00 2001
From: Alexander Romanov <alexander.romanov at syntacore.com>
Date: Thu, 14 Nov 2024 12:37:24 +0300
Subject: [PATCH] [llvm] fix nullptr dereference in
 BasicBlock::getIrrLoopHeaderWeight

Currently, calling `getIrrLoopHeaderWeight` on a `BasicBlock` that has no
terminator dereferences a null `TI` pointer, which is undefined behaviour.
This commit adds a check for this pointer and returns `std::nullopt` if
there is no terminator.

---
 llvm/lib/IR/BasicBlock.cpp                    | 20 ++++++++++---------
 .../CodeGen/MachineBasicBlockTest.cpp         | 11 ++++++++++
 llvm/unittests/IR/BasicBlockTest.cpp          | 10 ++++++++++
 3 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/IR/BasicBlock.cpp b/llvm/lib/IR/BasicBlock.cpp
index 0efc04cb2c8679..83428350b336b6 100644
--- a/llvm/lib/IR/BasicBlock.cpp
+++ b/llvm/lib/IR/BasicBlock.cpp
@@ -684,15 +684,17 @@ const LandingPadInst *BasicBlock::getLandingPadInst() const {
 
 std::optional<uint64_t> BasicBlock::getIrrLoopHeaderWeight() const {
   const Instruction *TI = getTerminator();
-  if (MDNode *MDIrrLoopHeader =
-      TI->getMetadata(LLVMContext::MD_irr_loop)) {
-    MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
-    if (MDName->getString() == "loop_header_weight") {
-      auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
-      return std::optional<uint64_t>(CI->getValue().getZExtValue());
-    }
-  }
-  return std::nullopt;
+  // A block without a terminator (e.g. an empty one) has no !irr_loop data.
+  if (!TI)
+    return std::nullopt;
+  MDNode *MDIrrLoopHeader = TI->getMetadata(LLVMContext::MD_irr_loop);
+  if (!MDIrrLoopHeader)
+    return std::nullopt;
+  MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
+  if (MDName->getString() != "loop_header_weight")
+    return std::nullopt;
+  auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
+  return CI->getValue().getZExtValue();
 }
 
 BasicBlock::iterator llvm::skipDebugIntrinsics(BasicBlock::iterator It) {
diff --git a/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp b/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp
index 25d54e8c80eec8..a1787f7de02a3a 100644
--- a/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp
+++ b/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp
@@ -105,4 +105,15 @@ TEST(FindDebugLocTest, DifferentIterators) {
   DIB.finalize();
 }
 
+TEST(MachineBasicBlockTest, EmptyBasicBlock) {
+  LLVMContext Ctx;
+  Module Mod("Module", Ctx);
+  auto MF = createMachineFunction(Ctx, Mod);
+  // An MBB can be created for an IR block that has no terminator yet.
+  auto &F = MF->getFunction();
+  auto *BB = BasicBlock::Create(F.getContext(), "", &F);
+  auto *MBB = MF->CreateMachineBasicBlock(BB);
+  EXPECT_EQ(MBB->getBasicBlock(), BB);
+  EXPECT_EQ(MBB->getIrrLoopHeaderWeight(), std::nullopt);
+}
 } // end namespace
diff --git a/llvm/unittests/IR/BasicBlockTest.cpp b/llvm/unittests/IR/BasicBlockTest.cpp
index 88ac6611742ce9..fbed69c3197011 100644
--- a/llvm/unittests/IR/BasicBlockTest.cpp
+++ b/llvm/unittests/IR/BasicBlockTest.cpp
@@ -583,5 +583,15 @@ TEST(BasicBlockTest, DiscardValueNames2) {
   }
 }
 
+TEST(BasicBlockTest, IrrLoopHeaderNull) {
+  LLVMContext Ctx;
+  Module M("Mod", Ctx);
+  auto *FTy = FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false);
+  auto *F = Function::Create(FTy, Function::ExternalLinkage, "foo", &M);
+  // A freshly created block has no terminator; the query must not crash.
+  auto *BB = BasicBlock::Create(Ctx, "", F);
+  EXPECT_EQ(BB->getIrrLoopHeaderWeight(), std::nullopt);
+}
+
 } // End anonymous namespace.
 } // End llvm namespace.



More information about the llvm-commits mailing list