[llvm] [AbstractCallSite] Handle Indirect Calls Properly (PR #163003)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Oct 11 07:46:41 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-ir
Author: Kunqiu Chen (Camsyn)
<details>
<summary>Changes</summary>
This patch fixes a bug in `AbstractCallSite` where indirect calls cause unexpected behavior and assertion failures.
#### **Problem**
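When the underlying call is **indirect**, several `AbstractCallSite` APIs behave unexpectedly. They use the underlying `CallBase` only when `isDirectCall()` is true, and otherwise fall through to code that assumes a callback call.
For example, `AbstractCallSite::getCalledFunction()` currently triggers an **assertion failure** instead of returning `nullptr` as documented: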
```cpp
/// Return the function being called if this is a direct call, otherwise
/// return null (if it's an indirect call).
Function *getCalledFunction() const;
```
This violates the intended contract of `AbstractCallSite` and makes it unsafe to use as a general call-site abstraction.
#### **Expected Behavior**
`AbstractCallSite` is designed to provide a **unified interface** for all three kinds of call sites: direct calls, indirect calls, and callback calls.
It should therefore handle indirect calls gracefully.
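To make the three kinds concrete, here is a minimal sketch of how the classification works, assuming the header's predicates behave as summarized in the comments. `CallSiteModel` and `checkClassification` are hypothetical stand-ins, not LLVM API. A non-empty parameter encoding marks a callback call. Otherwise the underlying call's callee decides between direct and indirect.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified model of how AbstractCallSite classifies a call
// site; it is not the LLVM class. `ParameterEncoding` mirrors the member of the
// same name (non-empty only for callback calls) and `CalleeIsIndirect` stands
// in for CallBase::isIndirectCall() on the underlying call.
struct CallSiteModel {
  std::vector<int> ParameterEncoding;
  bool CalleeIsIndirect = false;

  bool isCallbackCall() const { return !ParameterEncoding.empty(); }
  bool isDirectCall() const { return !isCallbackCall() && !CalleeIsIndirect; }
  bool isIndirectCall() const { return !isCallbackCall() && CalleeIsIndirect; }
};

// Checks that exactly one kind applies, and that "not a callback call" means
// "direct or indirect call". The `!isCallbackCall()` guards in this patch rely
// on that equivalence.
void checkClassification(const CallSiteModel &ACS) {
  int Kinds = ACS.isDirectCall() + ACS.isIndirectCall() + ACS.isCallbackCall();
  assert(Kinds == 1 && "kinds are mutually exclusive and exhaustive");
  assert(!ACS.isCallbackCall() == (ACS.isDirectCall() || ACS.isIndirectCall()));
  (void)Kinds; // silence unused-variable warnings when NDEBUG is set
}
```

For example, `CallSiteModel{{}, /*CalleeIsIndirect=*/true}` is an indirect call and passes the check. In this model, giving it a non-empty encoding turns it into a callback call.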
#### **Why It Went Unnoticed**
Currently, LLVM does not contain any in-tree use cases where `AbstractCallSite` wraps an **indirect call**, so this issue has not been triggered or tested before.
#### **Motivation**
* Improves robustness and aligns the implementation with the documented behavior.
* Enables future and third-party extensions that use `AbstractCallSite` to represent arbitrary call sites.
* Ensures `AbstractCallSite` can safely wrap indirect calls without breaking internal invariants.
#### **Fix**
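This pull request updates the dispatch logic in `AbstractCallSite` so that it distinguishes direct, indirect, and callback calls correctly. It also adds unit tests for direct and indirect calls to complement the existing callback-call test.
The main change is that five methods (`isCallee`, `getNumArgOperands`, `getCallArgOperandNo`, `getCallArgOperand`, and `getCalledOperand`) now check `if (!isCallbackCall())` rather than `if (isDirectCall())`. Direct and indirect calls both keep their callee and arguments in the underlying `CallBase`, so both should take that path. Only callback calls need the parameter encoding.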
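The following standalone model gives a rough picture of what the guard swap changes in three of the touched methods. Everything in it is hypothetical:

* `CallSiteModel` stands in for the real class.
* The `Patched` flag selects the old (`isDirectCall()`) or new (`!isCallbackCall()`) guard.
* `CalleeSource` records which branch `getCalledOperand()` would take instead of returning a `Value *`.

The method bodies paraphrase the header code in the diff, with operand numbers in place of `Value *`. This is a sketch under those assumptions, not LLVM code.

```cpp
#include <cassert>
#include <climits>
#include <vector>

// Hypothetical, simplified model of the dispatch changed by this patch; it is
// not LLVM code. Operand numbers stand in for the Value pointers the real
// methods return.
struct CallSiteModel {
  std::vector<int> ParameterEncoding; // non-empty only for callback calls
  bool CalleeIsIndirect = false;      // stands in for CallBase::isIndirectCall()
  unsigned NumCallArgs = 0;           // stands in for CallBase::arg_size()

  bool isCallbackCall() const { return !ParameterEncoding.empty(); }
  bool isDirectCall() const { return !isCallbackCall() && !CalleeIsIndirect; }

  // Old guard: `if (isDirectCall())`; new guard: `if (!isCallbackCall())`.
  bool usesCallBase(bool Patched) const {
    return Patched ? !isCallbackCall() : isDirectCall();
  }

  unsigned getNumArgOperands(bool Patched) const {
    if (usesCallBase(Patched))
      return NumCallArgs;
    // Subtract 1 for the callee encoding. An empty encoding (an indirect call
    // under the old guard) wraps around to UINT_MAX.
    return static_cast<unsigned>(ParameterEncoding.size() - 1);
  }

  int getCallArgOperandNo(unsigned ArgNo, bool Patched) const {
    if (usesCallBase(Patched))
      return static_cast<int>(ArgNo);
    // Add 1 for the callee encoding; -1 means "no operand". The real code uses
    // operator[]; .at() turns an out-of-range read into an exception here.
    return ParameterEncoding.at(ArgNo + 1);
  }

  enum class CalleeSource { CallBaseCalledOperand, CallbackArgOperand };

  CalleeSource getCalledOperandSource(bool Patched) const {
    if (usesCallBase(Patched))
      return CalleeSource::CallBaseCalledOperand; // CB->getCalledOperand()
    // CB->getArgOperand(getCallArgOperandNoForCallee()); in LLVM that helper
    // asserts isCallbackCall(), so an indirect call fails here before the fix.
    return CalleeSource::CallbackArgOperand;
  }
};

// An indirect call with no arguments, like `call void %0()` in the new test.
void demoIndirectCall() {
  CallSiteModel Indirect{{}, /*CalleeIsIndirect=*/true, /*NumCallArgs=*/0};
  assert(Indirect.getNumArgOperands(/*Patched=*/false) == UINT_MAX);
  assert(Indirect.getNumArgOperands(/*Patched=*/true) == 0);
  assert(Indirect.getCalledOperandSource(/*Patched=*/false) ==
         CallSiteModel::CalleeSource::CallbackArgOperand);
  assert(Indirect.getCalledOperandSource(/*Patched=*/true) ==
         CallSiteModel::CalleeSource::CallBaseCalledOperand);
  (void)Indirect; // silence unused-variable warnings when NDEBUG is set
}
```

Callback calls take the encoding path under both guards. With the encoding `{2, 1, -1}`, the callee is call operand 2 and `getNumArgOperands()` is 2. The first callee parameter maps to call operand 1, and the second has no operand (`-1`).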
---
Full diff: https://github.com/llvm/llvm-project/pull/163003.diff
2 Files Affected:
- (modified) llvm/include/llvm/IR/AbstractCallSite.h (+5-5)
- (modified) llvm/unittests/IR/AbstractCallSiteTest.cpp (+53)
``````````diff
diff --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h
index 9e24ae7d1b431..f431e1d8a38ef 100644
--- a/llvm/include/llvm/IR/AbstractCallSite.h
+++ b/llvm/include/llvm/IR/AbstractCallSite.h
@@ -137,7 +137,7 @@ class AbstractCallSite {
/// Return true if @p U is the use that defines the callee of this ACS.
bool isCallee(const Use *U) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->isCallee(U);
assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
/// Return the number of parameters of the callee.
unsigned getNumArgOperands() const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->arg_size();
// Subtract 1 for the callee encoding.
return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
/// Return the operand index of the underlying instruction associated with
/// the function parameter number @p ArgNo or -1 if there is none.
int getCallArgOperandNo(unsigned ArgNo) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return ArgNo;
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
/// Return the operand of the underlying instruction associated with the
/// function parameter number @p ArgNo or nullptr if there is none.
Value *getCallArgOperand(unsigned ArgNo) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->getArgOperand(ArgNo);
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
/// Return the pointer to function that is being called.
Value *getCalledOperand() const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->getCalledOperand();
return CB->getArgOperand(getCallArgOperandNoForCallee());
}
diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index ddb10911ad028..c30515a93b339 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -53,3 +53,56 @@ TEST(AbstractCallSite, CallbackCall) {
   EXPECT_TRUE(ACS.isCallee(CallbackUse));
   EXPECT_EQ(ACS.getCalledFunction(), Callback);
 }
+
+TEST(AbstractCallSite, DirectCall) {
+  LLVMContext C;
+
+  const char *IR = "declare void @bar()\n"
+                   "define void @foo() {\n"
+                   " call void @bar()\n"
+                   " ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Callee = M->getFunction("bar");
+  ASSERT_NE(Callee, nullptr);
+
+  const Use *DirectCallUse = Callee->getSingleUndroppableUse();
+  ASSERT_NE(DirectCallUse, nullptr);
+
+  AbstractCallSite ACS(DirectCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isDirectCall());
+  EXPECT_TRUE(ACS.isCallee(DirectCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), Callee);
+}
+
+TEST(AbstractCallSite, IndirectCall) {
+  LLVMContext C;
+
+  const char *IR = "define void @foo(ptr %0) {\n"
+                   " call void %0()\n"
+                   " ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Fun = M->getFunction("foo");
+  ASSERT_NE(Fun, nullptr);
+
+  Argument *ArgAsCallee = Fun->getArg(0);
+  ASSERT_NE(ArgAsCallee, nullptr);
+
+  const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
+  ASSERT_NE(IndCallUse, nullptr);
+
+  AbstractCallSite ACS(IndCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isIndirectCall());
+  EXPECT_TRUE(ACS.isCallee(IndCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), nullptr);
+  EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/163003