[llvm] [AbstractCallSite] Handle Indirect Calls Properly (PR #163003)

Kunqiu Chen via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 28 01:35:50 PDT 2025


https://github.com/Camsyn updated https://github.com/llvm/llvm-project/pull/163003

>From ca713e4a4624e289252296db1e0de2fbb5e9bd3a Mon Sep 17 00:00:00 2001
From: Camsyn <camsyn at foxmail.com>
Date: Sat, 11 Oct 2025 22:34:23 +0800
Subject: [PATCH 1/2] [AbstractCallSite] Handle Indirect Calls Properly

---
 llvm/include/llvm/IR/AbstractCallSite.h    | 10 ++--
 llvm/unittests/IR/AbstractCallSiteTest.cpp | 53 ++++++++++++++++++++++
 2 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h
index 9e24ae7d1b431..f431e1d8a38ef 100644
--- a/llvm/include/llvm/IR/AbstractCallSite.h
+++ b/llvm/include/llvm/IR/AbstractCallSite.h
@@ -137,7 +137,7 @@ class AbstractCallSite {
 
   /// Return true if @p U is the use that defines the callee of this ACS.
   bool isCallee(const Use *U) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->isCallee(U);
 
     assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
 
   /// Return the number of parameters of the callee.
   unsigned getNumArgOperands() const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->arg_size();
     // Subtract 1 for the callee encoding.
     return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
   /// Return the operand index of the underlying instruction associated with
   /// the function parameter number @p ArgNo or -1 if there is none.
   int getCallArgOperandNo(unsigned ArgNo) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return ArgNo;
     // Add 1 for the callee encoding.
     return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
   /// Return the operand of the underlying instruction associated with the
   /// function parameter number @p ArgNo or nullptr if there is none.
   Value *getCallArgOperand(unsigned ArgNo) const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->getArgOperand(ArgNo);
     // Add 1 for the callee encoding.
     return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
 
   /// Return the pointer to function that is being called.
   Value *getCalledOperand() const {
-    if (isDirectCall())
+    if (!isCallbackCall())
       return CB->getCalledOperand();
     return CB->getArgOperand(getCallArgOperandNoForCallee());
   }
diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index ddb10911ad028..c30515a93b339 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -53,3 +53,56 @@ TEST(AbstractCallSite, CallbackCall) {
   EXPECT_TRUE(ACS.isCallee(CallbackUse));
   EXPECT_EQ(ACS.getCalledFunction(), Callback);
 }
+
+TEST(AbstractCallSite, DirectCall) {
+  LLVMContext C;
+
+  const char *IR = "declare void @bar()\n"
+                   "define void @foo() {\n"
+                   "  call void @bar()\n"
+                   "  ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Callee = M->getFunction("bar");
+  ASSERT_NE(Callee, nullptr);
+
+  const Use *DirectCallUse = Callee->getSingleUndroppableUse();
+  ASSERT_NE(DirectCallUse, nullptr);
+
+  AbstractCallSite ACS(DirectCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isDirectCall());
+  EXPECT_TRUE(ACS.isCallee(DirectCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), Callee);
+}
+
+TEST(AbstractCallSite, IndirectCall) {
+  LLVMContext C;
+
+  const char *IR = "define void @foo(ptr %0) {\n"
+                   "  call void %0()\n"
+                   "  ret void\n"
+                   "}\n";
+
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  ASSERT_TRUE(M);
+
+  Function *Fun = M->getFunction("foo");
+  ASSERT_NE(Fun, nullptr);
+
+  Argument *ArgAsCallee = Fun->getArg(0);
+  ASSERT_NE(ArgAsCallee, nullptr);
+
+  const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
+  ASSERT_NE(IndCallUse, nullptr);
+
+  AbstractCallSite ACS(IndCallUse);
+  EXPECT_TRUE(ACS);
+  EXPECT_TRUE(ACS.isIndirectCall());
+  EXPECT_TRUE(ACS.isCallee(IndCallUse));
+  EXPECT_EQ(ACS.getCalledFunction(), nullptr);
+  EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+}

>From 1f6797a7f268748b23a041608c24d29b31f74ec8 Mon Sep 17 00:00:00 2001
From: Camsyn <camsyn at foxmail.com>
Date: Tue, 28 Oct 2025 16:35:29 +0800
Subject: [PATCH 2/2] Check more AbstractCallSite methods in unit tests

---
 llvm/unittests/IR/AbstractCallSiteTest.cpp | 47 ++++++++++++++++++++--
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index c30515a93b339..36aec5bdb114c 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -6,8 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/AbstractCallSite.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
@@ -51,15 +52,27 @@ TEST(AbstractCallSite, CallbackCall) {
   EXPECT_TRUE(ACS);
   EXPECT_TRUE(ACS.isCallbackCall());
   EXPECT_TRUE(ACS.isCallee(CallbackUse));
+  EXPECT_EQ(ACS.getCalleeUseForCallback(), *CallbackUse);
   EXPECT_EQ(ACS.getCalledFunction(), Callback);
+
+  // The callback metadata {CallbackNo, Arg0No, ..., isVarArg} is {1, -1, true}.
+  EXPECT_EQ(ACS.getCallArgOperandNoForCallee(), 1);
+  // Although the callback metadata lists only ONE argument index (-1, i.e. no
+  // call operand), it is marked vararg, so the call's third argument (operand
+  // 2) is also forwarded to the callback as its second argument.
+  EXPECT_EQ(ACS.getNumArgOperands(), 2);
+  Argument *Param0 = Callback->getArg(0), *Param1 = Callback->getArg(1);
+  ASSERT_TRUE(Param0 && Param1);
+  EXPECT_EQ(ACS.getCallArgOperandNo(*Param0), -1);
+  EXPECT_EQ(ACS.getCallArgOperandNo(*Param1), 2);
 }
 
 TEST(AbstractCallSite, DirectCall) {
   LLVMContext C;
 
-  const char *IR = "declare void @bar()\n"
+  const char *IR = "declare void @bar(i32 %x, i32 %y)\n"
                    "define void @foo() {\n"
-                   "  call void @bar()\n"
+                   "  call void @bar(i32 1, i32 2)\n"
                    "  ret void\n"
                    "}\n";
 
@@ -77,13 +90,30 @@ TEST(AbstractCallSite, DirectCall) {
   EXPECT_TRUE(ACS.isDirectCall());
   EXPECT_TRUE(ACS.isCallee(DirectCallUse));
   EXPECT_EQ(ACS.getCalledFunction(), Callee);
+  EXPECT_EQ(ACS.getNumArgOperands(), 2);
+  Argument *ArgX = Callee->getArg(0);
+  ASSERT_NE(ArgX, nullptr);
+  Value *CAO1 = ACS.getCallArgOperand(*ArgX);
+  Value *CAO2 = ACS.getCallArgOperand(0);
+  ASSERT_NE(CAO2, nullptr);
+  // The two call arg operands should be the same object, since they are both
+  // the first argument of the call.
+  EXPECT_EQ(CAO2, CAO1);
+
+  ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CAO2);
+  ASSERT_NE(FirstArgInt, nullptr);
+  EXPECT_EQ(FirstArgInt->getZExtValue(), 1);
+
+  EXPECT_EQ(ACS.getCallArgOperandNo(*ArgX), 0);
+  EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
+  EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
 }
 
 TEST(AbstractCallSite, IndirectCall) {
   LLVMContext C;
 
   const char *IR = "define void @foo(ptr %0) {\n"
-                   "  call void %0()\n"
+                   "  call void %0(i32 1, i32 2)\n"
                    "  ret void\n"
                    "}\n";
 
@@ -105,4 +135,13 @@ TEST(AbstractCallSite, IndirectCall) {
   EXPECT_TRUE(ACS.isCallee(IndCallUse));
   EXPECT_EQ(ACS.getCalledFunction(), nullptr);
   EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+  EXPECT_EQ(ACS.getNumArgOperands(), 2);
+  Value *CalledOperand = ACS.getCallArgOperand(0);
+  ASSERT_NE(CalledOperand, nullptr);
+  ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CalledOperand);
+  ASSERT_NE(FirstArgInt, nullptr);
+  EXPECT_EQ(FirstArgInt->getZExtValue(), 1);
+
+  EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
+  EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
 }



More information about the llvm-commits mailing list