[llvm] [NVPTX] Convert calls to indirect when call signature mismatches function signature (PR #107644)

Kevin McAfee via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 6 14:35:42 PDT 2024


https://github.com/kalxr created https://github.com/llvm/llvm-project/pull/107644

When the return type, any parameter type, or the parameter count of a call instruction does not match that of the callee, lower the call to an indirect call. The current behavior is to produce direct calls that may or may not be valid PTX. Consider the following example with mismatching return types:

```
%struct.1 = type <{i64}>
%struct.2 = type <{i64}>
declare %struct.1 @callee()
...
%call1 = call %struct.2 @callee()
%call2 = call i64 @callee()
```

The return type of `callee` in PTX is `.b8 _[8]`. The return type of `%call1` will be the same, so the PTX has no problems. The return type of `%call2` will be `.b64`, so the types will not match and the PTX will be unacceptable to ptxas. This happens despite all the types having the same size. The same is true for mismatching parameter types.

If we instead convert these calls to indirect calls, we will generate functional PTX when the types have the same size. If they do not have the same size then the PTX will be incorrect, though this will not necessarily be caught by ptxas. This change allows for more flexibility in the bitcode that can be lowered to functioning PTX, at the cost of sometimes producing PTX that is less clearly wrong than it would have been previously (i.e. incorrect indirect calls are not as obviously wrong as incorrect direct calls).

>From 4cc77c8f06be38af28cc8afbb803fcce4dbc96fd Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Fri, 30 Aug 2024 11:55:14 -0700
Subject: [PATCH] [NVPTX] Convert calls to indirect when call signature
 mismatches function signature

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 47 +++++++++-
 llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll |  4 +-
 .../CodeGen/NVPTX/convert-call-to-indirect.ll | 89 +++++++++++++++++++
 3 files changed, 137 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb76ffdfd99d7b..726493ccaa2569 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1657,6 +1657,33 @@ LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
   return RetVal;
 }
 
+static bool shouldConvertToIndirectCall(bool IsVarArg, unsigned ParamCount,
+                                        NVPTXTargetLowering::ArgListTy &Args,
+                                        const CallBase *CB,
+                                        GlobalAddressSDNode *Func) {
+  if (!Func)
+    return false;
+  auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal());
+  if (!CalleeFunc)
+    return false;
+
+  auto ActualReturnType = CalleeFunc->getReturnType();
+  if (CB->getType() != ActualReturnType)
+    return true;
+
+  if (IsVarArg)
+    return false;
+
+  auto ActualNumParams = CalleeFunc->getFunctionType()->getNumParams();
+  if (ParamCount != ActualNumParams)
+    return true;
+  for (const Argument &I : CalleeFunc->args())
+    if (I.getType() != Args[I.getArgNo()].Ty)
+      return true;
+
+  return false;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1971,10 +1998,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                     VADeclareParam->getVTList(), DeclareParamOps);
   }
 
+  // If the param count, type of any param, or return type of the callsite
+  // mismatches with that of the function signature, convert the callsite to an
+  // indirect call.
+  bool ConvertToIndirectCall =
+      shouldConvertToIndirectCall(CLI.IsVarArg, ParamCount, Args, CB, Func);
+
   // Both indirect calls and libcalls have nullptr Func. In order to distinguish
   // between them we must rely on the call site value which is valid for
   // indirect calls but is always null for libcalls.
-  bool isIndirectCall = !Func && CB;
+  bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
 
   if (isa<ExternalSymbolSDNode>(Callee)) {
     Function* CalleeFunc = nullptr;
@@ -2026,6 +2059,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
   InGlue = Chain.getValue(1);
 
+  if (ConvertToIndirectCall) {
+    // Copy the function ptr to a ptx register and use the register to call the
+    // function.
+    EVT DestVT = Callee.getValueType();
+    MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    unsigned DestReg =
+        RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
+    auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
+    Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
+  }
+
   // Ops to print out the function name
   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
   SDValue CallVoidOps[] = { Chain, Callee, InGlue };
diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
index c5f7bd1bd1ba20..bd723a296e620f 100644
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -17,8 +17,8 @@ target triple = "nvptx64-nvidia-cuda"
 ; CHECK: st.param.b16   [param2+0], %rs1;
 ; CHECK: st.param.b16   [param2+2], %rs2;
 ; CHECK: .param .align 2 .b8 retval0[4];
-; CHECK: call.uni (retval0),
-; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
+; CHECK-NEXT: prototype_0 : .callprototype (.param .align 2 .b8 _[4]) _ (.param .b32 _, .param .b32 _, .param .align 2 .b8 _[4]);
+; CHECK-NEXT: call (retval0),
 define weak_odr void @foo() {
 entry:
   %call.i.i.i = tail call %"class.complex" @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32 0, i32 0, ptr byval(%"class.complex") null)
diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
new file mode 100644
index 00000000000000..2602c3b0d041b5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -0,0 +1,89 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+%struct.64 = type <{ i64 }>
+declare i64 @callee(ptr %p);
+declare i64 @callee_variadic(ptr %p, ...);
+
+define %struct.64 @test_return_type_mismatch(ptr %p) {
+; CHECK-LABEL: test_return_type_mismatch(
+; CHECK:         .param .align 1 .b8 retval0[8];
+; CHECK-NEXT:    prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_0;
+  %ret = call %struct.64 @callee(ptr %p)
+  ret %struct.64 %ret
+}
+
+define i64 @test_param_type_mismatch(ptr %p) {
+; CHECK-LABEL: test_param_type_mismatch(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    prototype_1 : .callprototype (.param .b64 _) _ (.param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_1;
+  %ret = call i64 @callee(i64 7)
+  ret i64 %ret
+}
+
+define i64 @test_param_count_mismatch(ptr %p) {
+; CHECK-LABEL: test_param_count_mismatch(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0,
+; CHECK-NEXT:    param1
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_2;
+  %ret = call i64 @callee(ptr %p, i64 7)
+  ret i64 %ret
+}
+
+define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_return_type_mismatch_variadic(
+; CHECK:         .param .align 1 .b8 retval0[8];
+; CHECK-NEXT:    prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_3;
+  %ret = call %struct.64 (ptr, ...) @callee_variadic(ptr %p)
+  ret %struct.64 %ret
+}
+
+define i64 @test_param_type_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_param_type_mismatch_variadic(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee_variadic
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0,
+; CHECK-NEXT:    param1
+; CHECK-NEXT:    )
+  %ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
+  ret i64 %ret
+}
+
+define i64 @test_param_count_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_param_count_mismatch_variadic(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee_variadic
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0,
+; CHECK-NEXT:    param1
+; CHECK-NEXT:    )
+  %ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
+  ret i64 %ret
+}



More information about the llvm-commits mailing list