[llvm] a847134 - [AbstractCallSite] Handle Indirect Calls Properly (#163003)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 28 06:45:21 PDT 2025
Author: Kunqiu Chen
Date: 2025-10-28T21:45:17+08:00
New Revision: a8471342fae518796232208929a23c2b8a127a68
URL: https://github.com/llvm/llvm-project/commit/a8471342fae518796232208929a23c2b8a127a68
DIFF: https://github.com/llvm/llvm-project/commit/a8471342fae518796232208929a23c2b8a127a68.diff
LOG: [AbstractCallSite] Handle Indirect Calls Properly (#163003)
AbstractCallSite handles three types of calls (direct, indirect, and
callback).
This patch fixes the handling of indirect calls in several methods that
incorrectly assumed every non-direct call is a callback call.
It also adds two unit tests: one for direct calls and one for indirect
calls.
This incorrect assumption leads to the following problem:
---
## Problem
When the underlying call is **indirect**, some `AbstractCallSite` APIs
behave unexpectedly.
For example, `AbstractCallSite::getCalledFunction()` currently triggers an
**assertion failure** instead of returning `nullptr` as documented:
```cpp
/// Return the function being called if this is a direct call, otherwise
/// return null (if it's an indirect call).
Function *getCalledFunction() const;
```
Instead, the following assertion fails:
```
AbstractCallSite.h:197: int llvm::AbstractCallSite::getCallArgOperandNoForCallee() const: Assertion `isCallbackCall()' failed.
```
This happens because `AbstractCallSite` mistakenly enters the branch that
handles callback calls: its guard condition (`!isDirectCall()`) does not
account for indirect calls, so an indirect call is treated as a callback call.
Added:
Modified:
llvm/include/llvm/IR/AbstractCallSite.h
llvm/unittests/IR/AbstractCallSiteTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h
index 9e24ae7d1b431..f431e1d8a38ef 100644
--- a/llvm/include/llvm/IR/AbstractCallSite.h
+++ b/llvm/include/llvm/IR/AbstractCallSite.h
@@ -137,7 +137,7 @@ class AbstractCallSite {
/// Return true if @p U is the use that defines the callee of this ACS.
bool isCallee(const Use *U) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->isCallee(U);
assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
/// Return the number of parameters of the callee.
unsigned getNumArgOperands() const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->arg_size();
// Subtract 1 for the callee encoding.
return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
/// Return the operand index of the underlying instruction associated with
/// the function parameter number @p ArgNo or -1 if there is none.
int getCallArgOperandNo(unsigned ArgNo) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return ArgNo;
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
/// Return the operand of the underlying instruction associated with the
/// function parameter number @p ArgNo or nullptr if there is none.
Value *getCallArgOperand(unsigned ArgNo) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->getArgOperand(ArgNo);
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
/// Return the pointer to function that is being called.
Value *getCalledOperand() const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->getCalledOperand();
return CB->getArgOperand(getCallArgOperandNoForCallee());
}
diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index ddb10911ad028..623d1b36e1c03 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -6,8 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/AbstractCallSite.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Argument.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
@@ -51,5 +52,96 @@ TEST(AbstractCallSite, CallbackCall) {
EXPECT_TRUE(ACS);
EXPECT_TRUE(ACS.isCallbackCall());
EXPECT_TRUE(ACS.isCallee(CallbackUse));
+ EXPECT_EQ(ACS.getCalleeUseForCallback(), *CallbackUse);
EXPECT_EQ(ACS.getCalledFunction(), Callback);
+
+ // The callback metadata {CallbackNo, Arg0No, ..., isVarArg} = {1, -1, true}
+ EXPECT_EQ(ACS.getCallArgOperandNoForCallee(), 1);
+ // Though the callback metadata only specifies ONE unfixed argument No, the
+ // callback callee is vararg, hence the third arg is also considered as
+ // another arg for the callback.
+ EXPECT_EQ(ACS.getNumArgOperands(), 2u);
+ Argument *Param0 = Callback->getArg(0), *Param1 = Callback->getArg(1);
+ ASSERT_TRUE(Param0 && Param1);
+ EXPECT_EQ(ACS.getCallArgOperandNo(*Param0), -1);
+ EXPECT_EQ(ACS.getCallArgOperandNo(*Param1), 2);
+}
+
+TEST(AbstractCallSite, DirectCall) {
+ LLVMContext C;
+
+ const char *IR = "declare void @bar(i32 %x, i32 %y)\n"
+ "define void @foo() {\n"
+ " call void @bar(i32 1, i32 2)\n"
+ " ret void\n"
+ "}\n";
+
+ std::unique_ptr<Module> M = parseIR(C, IR);
+ ASSERT_TRUE(M);
+
+ Function *Callee = M->getFunction("bar");
+ ASSERT_NE(Callee, nullptr);
+
+ const Use *DirectCallUse = Callee->getSingleUndroppableUse();
+ ASSERT_NE(DirectCallUse, nullptr);
+
+ AbstractCallSite ACS(DirectCallUse);
+ EXPECT_TRUE(ACS);
+ EXPECT_TRUE(ACS.isDirectCall());
+ EXPECT_TRUE(ACS.isCallee(DirectCallUse));
+ EXPECT_EQ(ACS.getCalledFunction(), Callee);
+ EXPECT_EQ(ACS.getNumArgOperands(), 2u);
+ Argument *ArgX = Callee->getArg(0);
+ ASSERT_NE(ArgX, nullptr);
+ Value *CAO1 = ACS.getCallArgOperand(*ArgX);
+ Value *CAO2 = ACS.getCallArgOperand(0);
+ ASSERT_NE(CAO2, nullptr);
+ // The two call arg operands should be the same object, since they are both
+ // the first argument of the call.
+ EXPECT_EQ(CAO2, CAO1);
+
+ ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CAO2);
+ ASSERT_NE(FirstArgInt, nullptr);
+ EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);
+
+ EXPECT_EQ(ACS.getCallArgOperandNo(*ArgX), 0);
+ EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
+ EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
+}
+
+TEST(AbstractCallSite, IndirectCall) {
+ LLVMContext C;
+
+ const char *IR = "define void @foo(ptr %0) {\n"
+ " call void %0(i32 1, i32 2)\n"
+ " ret void\n"
+ "}\n";
+
+ std::unique_ptr<Module> M = parseIR(C, IR);
+ ASSERT_TRUE(M);
+
+ Function *Fun = M->getFunction("foo");
+ ASSERT_NE(Fun, nullptr);
+
+ Argument *ArgAsCallee = Fun->getArg(0);
+ ASSERT_NE(ArgAsCallee, nullptr);
+
+ const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
+ ASSERT_NE(IndCallUse, nullptr);
+
+ AbstractCallSite ACS(IndCallUse);
+ EXPECT_TRUE(ACS);
+ EXPECT_TRUE(ACS.isIndirectCall());
+ EXPECT_TRUE(ACS.isCallee(IndCallUse));
+ EXPECT_EQ(ACS.getCalledFunction(), nullptr);
+ EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+ EXPECT_EQ(ACS.getNumArgOperands(), 2u);
+ Value *CalledOperand = ACS.getCallArgOperand(0);
+ ASSERT_NE(CalledOperand, nullptr);
+ ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CalledOperand);
+ ASSERT_NE(FirstArgInt, nullptr);
+ EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);
+
+ EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
+ EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
}
More information about the llvm-commits
mailing list