[llvm] [llvm] fix nullptr dereference in BasicBlock::getIrrLoopHeaderWeight (PR #116192)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 14 01:44:08 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Alexander Romanov (arrv-sc)

<details>
<summary>Changes</summary>

Currently, if you call `getIrrLoopHeaderWeight` on a `BasicBlock` that has no terminator, the variable `TI` will be `nullptr`, and dereferencing it is undefined behaviour. This commit adds a check for this pointer and returns `std::nullopt` if no terminator is found.

---
Full diff: https://github.com/llvm/llvm-project/pull/116192.diff


3 Files Affected:

- (modified) llvm/lib/IR/BasicBlock.cpp (+11-9) 
- (modified) llvm/unittests/CodeGen/MachineBasicBlockTest.cpp (+11) 
- (modified) llvm/unittests/IR/BasicBlockTest.cpp (+10) 


``````````diff
diff --git a/llvm/lib/IR/BasicBlock.cpp b/llvm/lib/IR/BasicBlock.cpp
index 0efc04cb2c8679..83428350b336b6 100644
--- a/llvm/lib/IR/BasicBlock.cpp
+++ b/llvm/lib/IR/BasicBlock.cpp
@@ -684,15 +684,17 @@ const LandingPadInst *BasicBlock::getLandingPadInst() const {
 
 std::optional<uint64_t> BasicBlock::getIrrLoopHeaderWeight() const {
   const Instruction *TI = getTerminator();
-  if (MDNode *MDIrrLoopHeader =
-      TI->getMetadata(LLVMContext::MD_irr_loop)) {
-    MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
-    if (MDName->getString() == "loop_header_weight") {
-      auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
-      return std::optional<uint64_t>(CI->getValue().getZExtValue());
-    }
-  }
-  return std::nullopt;
+  if (!TI)
+    return std::nullopt;
+  MDNode *MDIrrLoopHeader = TI->getMetadata(LLVMContext::MD_irr_loop);
+  if (!MDIrrLoopHeader)
+    return std::nullopt;
+  MDString *MDName = cast<MDString>(MDIrrLoopHeader->getOperand(0));
+  assert(MDName);
+  if (MDName->getString() != "loop_header_weight")
+    return std::nullopt;
+  auto *CI = mdconst::extract<ConstantInt>(MDIrrLoopHeader->getOperand(1));
+  return std::optional<uint64_t>(CI->getValue().getZExtValue());
 }
 
 BasicBlock::iterator llvm::skipDebugIntrinsics(BasicBlock::iterator It) {
diff --git a/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp b/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp
index 25d54e8c80eec8..a1787f7de02a3a 100644
--- a/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp
+++ b/llvm/unittests/CodeGen/MachineBasicBlockTest.cpp
@@ -105,4 +105,15 @@ TEST(FindDebugLocTest, DifferentIterators) {
   DIB.finalize();
 }
 
+TEST(MachineBasicBlockTest, EmptyBasicBlock) {
+  LLVMContext Ctx;
+  Module Mod("Module", Ctx);
+  // Test that it is possible to create MachineBasicBlock for BasicBlock
+  auto &F = MF->getFunction();
+  auto *BB = BasicBlock::Create(F.getContext(), "", &F);
+  auto *MBB = MF->CreateMachineBasicBlock(BB);
+  EXPECT_NE(MBB, nullptr);
+  EXPECT_EQ(MBB->size(), 0);
+  EXPECT_EQ(MBB->getBasicBlock(), BB);
+}
 } // end namespace
diff --git a/llvm/unittests/IR/BasicBlockTest.cpp b/llvm/unittests/IR/BasicBlockTest.cpp
index 88ac6611742ce9..fbed69c3197011 100644
--- a/llvm/unittests/IR/BasicBlockTest.cpp
+++ b/llvm/unittests/IR/BasicBlockTest.cpp
@@ -583,5 +583,15 @@ TEST(BasicBlockTest, DiscardValueNames2) {
   }
 }
 
+TEST(BasicBlockTest, IrrLoopHeaderNull) {
+  LLVMContext Ctx;
+  Module M("Mod", Ctx);
+  auto *FTy = FunctionType::get(Type::getVoidTy(M.getContext()),
+                                Type::getInt32Ty(Ctx), /* isVarArg */ false);
+  auto *F = Function::Create(FTy, Function::ExternalLinkage, "foo", &M);
+  auto *BB = BasicBlock::Create(Ctx, "", F);
+  EXPECT_EQ(BB->getIrrLoopHeaderWeight(), std::nullopt);
+}
+
 } // End anonymous namespace.
 } // End llvm namespace.

``````````

</details>


https://github.com/llvm/llvm-project/pull/116192


More information about the llvm-commits mailing list