[llvm] a847134 - [AbstractCallSite] Handle Indirect Calls Properly (#163003)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 28 06:45:21 PDT 2025


Author: Kunqiu Chen
Date: 2025-10-28T21:45:17+08:00
New Revision: a8471342fae518796232208929a23c2b8a127a68

URL: https://github.com/llvm/llvm-project/commit/a8471342fae518796232208929a23c2b8a127a68
DIFF: https://github.com/llvm/llvm-project/commit/a8471342fae518796232208929a23c2b8a127a68.diff

LOG: [AbstractCallSite] Handle Indirect Calls Properly (#163003)

AbstractCallSite handles three types of calls (direct, indirect, and
callback).

This patch fixes the handling of indirect calls in some methods, which
incorrectly assumed that non-direct calls are always callback calls.

Moreover, this PR adds two unit tests: one for the direct call type and
one for the indirect call type.

The aforementioned incorrect assumption leads to the following problem:

---
## Problem

When the underlying call is **indirect**, some APIs of
`AbstractCallSite` behave unexpectedly.
E.g., `AbstractCallSite::getCalledFunction()` currently triggers an
**assertion failure**, instead of returning `nullptr` as documented:

```cpp
/// Return the function being called if this is a direct call, otherwise
/// return null (if it's an indirect call).
Function *getCalledFunction() const;
```

Actual unexpected assertion failure:
```
AbstractCallSite.h:197: int llvm::AbstractCallSite::getCallArgOperandNoForCallee() const: Assertion `isCallbackCall()' failed.
```

This is because `AbstractCallSite` mistakenly enters the branch that
handles callback calls: its guard condition (`!isDirectCall()`) does
not take indirect calls into account.

Added: 
    

Modified: 
    llvm/include/llvm/IR/AbstractCallSite.h
    llvm/unittests/IR/AbstractCallSiteTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h
index 9e24ae7d1b431..f431e1d8a38ef 100644
--- a/llvm/include/llvm/IR/AbstractCallSite.h
+++ b/llvm/include/llvm/IR/AbstractCallSite.h
@@ -137,7 +137,7 @@ class AbstractCallSite {
 
   /// Return true if @p U is the use that defines the callee of this ACS.
   bool isCallee(const Use *U) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->isCallee(U);
 
     assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
 
   /// Return the number of parameters of the callee.
   unsigned getNumArgOperands() const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->arg_size();
     // Subtract 1 for the callee encoding.
     return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
   /// Return the operand index of the underlying instruction associated with
   /// the function parameter number @p ArgNo or -1 if there is none.
   int getCallArgOperandNo(unsigned ArgNo) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return ArgNo;
     // Add 1 for the callee encoding.
     return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
   /// Return the operand of the underlying instruction associated with the
   /// function parameter number @p ArgNo or nullptr if there is none.
   Value *getCallArgOperand(unsigned ArgNo) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->getArgOperand(ArgNo);
     // Add 1 for the callee encoding.
     return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
 
   /// Return the pointer to function that is being called.
   Value *getCalledOperand() const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->getCalledOperand();
     return CB->getArgOperand(getCallArgOperandNoForCallee());
   }

diff  --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index ddb10911ad028..623d1b36e1c03 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -6,8 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/AbstractCallSite.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
@@ -51,5 +52,96 @@ TEST(AbstractCallSite, CallbackCall) {
   EXPECT_TRUE(ACS);
   EXPECT_TRUE(ACS.isCallbackCall());
   EXPECT_TRUE(ACS.isCallee(CallbackUse));
+  EXPECT_EQ(ACS.getCalleeUseForCallback(), *CallbackUse);
   EXPECT_EQ(ACS.getCalledFunction(), Callback);
+
+  // The callback metadata {CallbackNo, Arg0No, ..., isVarArg} = {1, -1, true}
+  EXPECT_EQ(ACS.getCallArgOperandNoForCallee(), 1);
+  // Though the callback metadata only specifies ONE unfixed argument No, the
+  // callback callee is vararg, hence the third arg is also considered as
+  // another arg for the callback.
+  EXPECT_EQ(ACS.getNumArgOperands(), 2u);
+  Argument *Param0 = Callback->getArg(0), *Param1 = Callback->getArg(1);
+  ASSERT_TRUE(Param0 && Param1);
+  EXPECT_EQ(ACS.getCallArgOperandNo(*Param0), -1);
+  EXPECT_EQ(ACS.getCallArgOperandNo(*Param1), 2);
+}
+
+TEST(AbstractCallSite, DirectCall) {
+  LLVMContext C;
+
+  const char *IR = "declare void @bar(i32 %x, i32 %y)\n"
+                   "define void @foo() {\n"
+                   "  call void @bar(i32 1, i32 2)\n"
+                   "  ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Callee = M->getFunction("bar");
+  ASSERT_NE(Callee, nullptr);
+
+  const Use *DirectCallUse = Callee->getSingleUndroppableUse();
+  ASSERT_NE(DirectCallUse, nullptr);
+
+  AbstractCallSite ACS(DirectCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isDirectCall());
+  EXPECT_TRUE(ACS.isCallee(DirectCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), Callee);
+  EXPECT_EQ(ACS.getNumArgOperands(), 2u);
+  Argument *ArgX = Callee->getArg(0);
+  ASSERT_NE(ArgX, nullptr);
+  Value *CAO1 = ACS.getCallArgOperand(*ArgX);
+  Value *CAO2 = ACS.getCallArgOperand(0);
+  ASSERT_NE(CAO2, nullptr);
+  // The two call arg operands should be the same object, since they are both
+  // the first argument of the call.
+  EXPECT_EQ(CAO2, CAO1);
+
+  ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CAO2);
+  ASSERT_NE(FirstArgInt, nullptr);
+  EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);
+
+  EXPECT_EQ(ACS.getCallArgOperandNo(*ArgX), 0);
+  EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
+  EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
+}
+
+TEST(AbstractCallSite, IndirectCall) {
+  LLVMContext C;
+
+  const char *IR = "define void @foo(ptr %0) {\n"
+                   "  call void %0(i32 1, i32 2)\n"
+                   "  ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Fun = M->getFunction("foo");
+  ASSERT_NE(Fun, nullptr);
+
+  Argument *ArgAsCallee = Fun->getArg(0);
+  ASSERT_NE(ArgAsCallee, nullptr);
+
+  const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
+  ASSERT_NE(IndCallUse, nullptr);
+
+  AbstractCallSite ACS(IndCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isIndirectCall());
+  EXPECT_TRUE(ACS.isCallee(IndCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), nullptr);
+  EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+  EXPECT_EQ(ACS.getNumArgOperands(), 2u);
+  Value *CalledOperand = ACS.getCallArgOperand(0);
+  ASSERT_NE(CalledOperand, nullptr);
+  ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CalledOperand);
+  ASSERT_NE(FirstArgInt, nullptr);
+  EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);
+
+  EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
+  EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
 }


        


More information about the llvm-commits mailing list