[llvm] Add new helper `Module::getRequiredFunction`. (PR #82761)

Thomas Symalla via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 23 05:13:20 PST 2024


https://github.com/tsymalla created https://github.com/llvm/llvm-project/pull/82761

Sometimes a function is strictly required to exist in a module, so every caller of `Module::getFunction` has to check the result for null and bail out itself.

Add a new helper, `Module::getRequiredFunction`, that performs this check and reports a fatal error if the function is missing.

`RewriteDescriptor` requires `getFunction` to take exactly one argument, so the check cannot be added to `getFunction` as an optional parameter.

Since functions cannot be overloaded on the basis of a default argument value, I added a separate helper instead.

From fb9a75f9b1a8a56b85455e32dda6cdf2d01d3255 Mon Sep 17 00:00:00 2001
From: Thomas Symalla <tsymalla at amd.com>
Date: Fri, 23 Feb 2024 13:31:35 +0100
Subject: [PATCH] Add new helper `Module::getRequiredFunction`.

Sometimes a function is strictly required to exist in a module, so
every caller of `Module::getFunction` has to check the result for null
and bail out itself.

Add a new helper, `Module::getRequiredFunction`, that performs this
check and reports a fatal error if the function is missing.

`RewriteDescriptor` requires `getFunction` to take exactly one
argument, so the check cannot be added to `getFunction` as an optional
parameter.

Since functions cannot be overloaded on the basis of a default argument
value, I added a separate helper instead.
---
 llvm/include/llvm/IR/Module.h               |  4 ++++
 llvm/lib/IR/Module.cpp                      | 12 ++++++++++++
 llvm/unittests/Analysis/CFGTest.cpp         |  4 +---
 llvm/unittests/Analysis/VectorUtilsTest.cpp |  4 +---
 4 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/IR/Module.h b/llvm/include/llvm/IR/Module.h
index 68a89dc45c2834..db4df92cd135d5 100644
--- a/llvm/include/llvm/IR/Module.h
+++ b/llvm/include/llvm/IR/Module.h
@@ -424,6 +424,10 @@ class LLVM_EXTERNAL_VISIBILITY Module {
   /// exist, return null.
   Function *getFunction(StringRef Name) const;
 
+  /// Look up the specified function in the module symbol table.
+  /// If the function does not exist, report a fatal error.
+  Function *getRequiredFunction(StringRef Name) const;
+
 /// @}
 /// @name Global Variable Accessors
 /// @{
diff --git a/llvm/lib/IR/Module.cpp b/llvm/lib/IR/Module.cpp
index 1946db2ee0be7e..1d40ba083c789f 100644
--- a/llvm/lib/IR/Module.cpp
+++ b/llvm/lib/IR/Module.cpp
@@ -170,6 +170,18 @@ Function *Module::getFunction(StringRef Name) const {
   return dyn_cast_or_null<Function>(getNamedValue(Name));
 }
 
+// getRequiredFunction - Look up the specified function in the module symbol
+// table. If the function does not exist, report a fatal error.
+//
+Function *Module::getRequiredFunction(StringRef Name) const {
+  Function *F = getFunction(Name);
+  if (!F)
+    report_fatal_error(Twine("Required function '@") + Name +
+                       "' not found in module '" + getName() + "'!");
+
+  return F;
+}
+
 //===----------------------------------------------------------------------===//
 // Methods for easy access to the global variables in the module.
 //
diff --git a/llvm/unittests/Analysis/CFGTest.cpp b/llvm/unittests/Analysis/CFGTest.cpp
index 46164268468628..c822ff4b52f7fd 100644
--- a/llvm/unittests/Analysis/CFGTest.cpp
+++ b/llvm/unittests/Analysis/CFGTest.cpp
@@ -41,9 +41,7 @@ class IsPotentiallyReachableTest : public testing::Test {
     if (!M)
       report_fatal_error(os.str().c_str());
 
-    Function *F = M->getFunction("test");
-    if (F == nullptr)
-      report_fatal_error("Test must have a function named @test");
+    Function *F = M->getRequiredFunction("test");
 
     A = B = nullptr;
     for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index 14958aa646a04d..08d7a85ef4fcb0 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -37,9 +37,7 @@ class VectorUtilsTest : public testing::Test {
     if (!M)
       report_fatal_error(Twine(os.str()));
 
-    Function *F = M->getFunction("test");
-    if (F == nullptr)
-      report_fatal_error("Test must have a function named @test");
+    Function *F = M->getRequiredFunction("test");
 
     A = nullptr;
     for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {