[llvm] [AbstractCallSite] Handle Indirect Calls Properly (PR #163003)

via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 11 07:46:41 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-llvm-ir

Author: Kunqiu Chen (Camsyn)

<details>
<summary>Changes</summary>


This patch fixes a bug in `AbstractCallSite` where indirect calls cause unexpected behavior and assertion failures.

#### **Problem**

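When the underlying call is **indirect**, several `AbstractCallSite` methods misbehave. They test `isDirectCall()`, which is false for an indirect call, so they fall through to the callback-call path and read the callback encoding `CI.ParameterEncoding`, which is empty for non-callback calls.
For example, `AbstractCallSite::getCalledFunction()` currently triggers an **assertion failure** instead of returning `nullptr` as documented: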

```cpp
/// Return the function being called if this is a direct call, otherwise
/// return null (if it's an indirect call).
Function *getCalledFunction() const;
```

This violates the intended contract of `AbstractCallSite` and makes it unsafe to use as a general call-site abstraction.

#### **Expected Behavior**

`AbstractCallSite` is designed to provide a **unified interface** for all three kinds of call sites: direct calls, indirect calls, and callback calls.

It should gracefully handle indirect calls.
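To make the three-way split concrete, here is a toy C++ model (hypothetical, not LLVM code) of how the kind predicates relate. It assumes, as in the header, that a call site is a callback call exactly when its `ParameterEncoding` is non-empty, and uses a plain `CalleeIsIndirect` flag as a stand-in for `CallBase::isIndirectCall()` on the underlying call. `ToyCallSite` and its fields are illustrative simplifications.

```cpp
#include <cassert>
#include <vector>

// Toy model of AbstractCallSite's kind predicates (hypothetical, not LLVM code).
struct ToyCallSite {
  // Stand-in for CB->isIndirectCall() on the underlying call.
  bool CalleeIsIndirect = false;
  // Stand-in for CI.ParameterEncoding: empty unless this is a callback call.
  std::vector<int> ParameterEncoding;

  bool isCallbackCall() const { return !ParameterEncoding.empty(); }
  bool isDirectCall() const { return !isCallbackCall() && !CalleeIsIndirect; }
  bool isIndirectCall() const { return !isCallbackCall() && CalleeIsIndirect; }

  // Exactly one of the three predicates should hold for every call site.
  bool hasExactlyOneKind() const {
    return isCallbackCall() + isDirectCall() + isIndirectCall() == 1;
  }
};
```

In this model `ToyCallSite{true, {}}` plays the role of `call void %0()` and reports `isIndirectCall()`. Note that `!isDirectCall()` is true for both indirect and callback calls, so it cannot by itself select the callback-only code path.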

#### **Why It Went Unnoticed**

LLVM currently has no in-tree users that wrap an **indirect call** in an `AbstractCallSite`, so this code path has never been exercised or tested.

#### **Motivation**

* Improves robustness and aligns the implementation with the documented behavior.
* Enables future and third-party extensions that use `AbstractCallSite` to represent arbitrary call sites.
* Ensures `AbstractCallSite` can safely wrap indirect calls without breaking internal invariants.

#### **Fix**

This pull request updates the logic in the `AbstractCallSite` class to correctly distinguish direct, indirect, and callback calls, and adds unit tests for direct and indirect calls alongside the existing callback-call test.

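The main change is that five methods (`isCallee`, `getNumArgOperands`, `getCallArgOperandNo`, `getCallArgOperand`, and `getCalledOperand`) now check `if (!isCallbackCall())` instead of `if (isDirectCall())`. Direct and indirect calls both carry their callee and arguments as ordinary operands of the underlying `CallBase`, so both should take that path; only callback calls need the `ParameterEncoding` indirection.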
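The effect of switching the guard can be sketched with a self-contained toy model (hypothetical, not LLVM code) of `getCalledOperand()` before and after the patch. Strings stand in for `Value *`, and `calleeArgNo()` is modeled on the callback-only helper `getCallArgOperandNoForCallee()`; it throws a `std::logic_error` where the real code would fail, so the bug is observable. `ToyDispatch` and its members are illustrative names only.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Toy model (hypothetical, not LLVM code): strings stand in for Value *.
struct ToyDispatch {
  std::string CalledOperand;          // stand-in for CB->getCalledOperand()
  bool CalleeIsIndirect = false;      // stand-in for CB->isIndirectCall()
  std::vector<std::string> Args;      // stand-in for CB's argument operands
  std::vector<int> ParameterEncoding; // non-empty only for callback calls

  bool isCallbackCall() const { return !ParameterEncoding.empty(); }
  bool isDirectCall() const { return !isCallbackCall() && !CalleeIsIndirect; }

  // Modeled on getCallArgOperandNoForCallee(): only meaningful for callbacks.
  int calleeArgNo() const {
    if (!isCallbackCall())
      throw std::logic_error("callee encoding requested for a non-callback call");
    return ParameterEncoding[0];
  }

  // Before the patch: an indirect call fails isDirectCall() and falls
  // through to the callback-only path.
  std::string getCalledOperandBefore() const {
    if (isDirectCall())
      return CalledOperand;
    return Args.at(calleeArgNo());
  }

  // After the patch: direct and indirect calls both use the real callee.
  std::string getCalledOperandAfter() const {
    if (!isCallbackCall())
      return CalledOperand;
    return Args.at(calleeArgNo());
  }
};
```

With this model, an indirect call such as `ToyDispatch{"%0", true, {}, {}}` throws from `getCalledOperandBefore()` but returns `"%0"` from `getCalledOperandAfter()`, while direct and callback calls behave the same under both guards. The same reasoning applies to the other four methods the patch touches.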






---
Full diff: https://github.com/llvm/llvm-project/pull/163003.diff


2 Files Affected:

- (modified) llvm/include/llvm/IR/AbstractCallSite.h (+5-5) 
- (modified) llvm/unittests/IR/AbstractCallSiteTest.cpp (+53) 


``````````diff
diff --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h
index 9e24ae7d1b431..f431e1d8a38ef 100644
--- a/llvm/include/llvm/IR/AbstractCallSite.h
+++ b/llvm/include/llvm/IR/AbstractCallSite.h
@@ -137,7 +137,7 @@ class AbstractCallSite {
 
   /// Return true if @p U is the use that defines the callee of this ACS.
   bool isCallee(const Use *U) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->isCallee(U);
 
     assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
 
   /// Return the number of parameters of the callee.
   unsigned getNumArgOperands() const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->arg_size();
     // Subtract 1 for the callee encoding.
     return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
   /// Return the operand index of the underlying instruction associated with
   /// the function parameter number @p ArgNo or -1 if there is none.
   int getCallArgOperandNo(unsigned ArgNo) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return ArgNo;
     // Add 1 for the callee encoding.
     return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
   /// Return the operand of the underlying instruction associated with the
   /// function parameter number @p ArgNo or nullptr if there is none.
   Value *getCallArgOperand(unsigned ArgNo) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->getArgOperand(ArgNo);
     // Add 1 for the callee encoding.
     return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
 
   /// Return the pointer to function that is being called.
   Value *getCalledOperand() const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->getCalledOperand();
     return CB->getArgOperand(getCallArgOperandNoForCallee());
   }
diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index ddb10911ad028..c30515a93b339 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -53,3 +53,56 @@ TEST(AbstractCallSite, CallbackCall) {
   EXPECT_TRUE(ACS.isCallee(CallbackUse));
   EXPECT_EQ(ACS.getCalledFunction(), Callback);
 }
+
+TEST(AbstractCallSite, DirectCall) {
+  LLVMContext C;
+
+  const char *IR = "declare void @bar()\n"
+                   "define void @foo() {\n"
+                   "  call void @bar()\n"
+                   "  ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Callee = M->getFunction("bar");
+  ASSERT_NE(Callee, nullptr);
+
+  const Use *DirectCallUse = Callee->getSingleUndroppableUse();
+  ASSERT_NE(DirectCallUse, nullptr);
+
+  AbstractCallSite ACS(DirectCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isDirectCall());
+  EXPECT_TRUE(ACS.isCallee(DirectCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), Callee);
+}
+
+TEST(AbstractCallSite, IndirectCall) {
+  LLVMContext C;
+
+  const char *IR = "define void @foo(ptr %0) {\n"
+                   "  call void %0()\n"
+                   "  ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Fun = M->getFunction("foo");
+  ASSERT_NE(Fun, nullptr);
+
+  Argument *ArgAsCallee = Fun->getArg(0);
+  ASSERT_NE(ArgAsCallee, nullptr);
+
+  const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
+  ASSERT_NE(IndCallUse, nullptr);
+
+  AbstractCallSite ACS(IndCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isIndirectCall());
+  EXPECT_TRUE(ACS.isCallee(IndCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), nullptr);
+  EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/163003

