[llvm] 62cdc2a - [NVPTX] Convert calls to indirect when call signature mismatches function signature (#107644)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 16 13:08:22 PDT 2024


Author: Kevin McAfee
Date: 2024-09-16T13:08:18-07:00
New Revision: 62cdc2a347584f32dd9c351b2384c873da0f32ad

URL: https://github.com/llvm/llvm-project/commit/62cdc2a347584f32dd9c351b2384c873da0f32ad
DIFF: https://github.com/llvm/llvm-project/commit/62cdc2a347584f32dd9c351b2384c873da0f32ad.diff

LOG: [NVPTX] Convert calls to indirect when call signature mismatches function signature (#107644)

When the signature of a call instruction does not match the signature
of the callee, lower the call to an indirect call. The previous
behavior was to produce direct calls that may or may not be valid
PTX. Consider the following example with mismatched return
types:

```
%struct.1 = type <{i64}>
%struct.2 = type <{i64}>
declare %struct.1 @callee()
...
%call1 = call %struct.2 @callee()
%call2 = call i64 @callee()
```

The return type of `callee` in PTX is `.b8 _[8]`. The return type of
`%call1` will be the same, so that PTX is fine. The return type of
`%call2`, however, will be `.b64`; the types will not match and the PTX
will be rejected by ptxas, even though all of the types have the same
size. The same applies to mismatched parameter types.
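For illustration, the direct call previously emitted for `%call2` looks
roughly like this (a hand-written sketch with illustrative register names,
not verbatim compiler output):

```
// Declaration of callee as lowered by the backend: aggregate return.
.extern .func (.param .align 1 .b8 func_retval0[8]) callee ();
...
{ // call site for %call2
  .param .b64 retval0;              // scalar .b64 vs. .b8 _[8] above
  call.uni (retval0), callee, ();   // rejected by ptxas: return types differ
  ld.param.b64 %rd1, [retval0+0];
}
```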

If we instead convert these calls to indirect calls, we generate
functional PTX whenever the types have the same size. If the sizes
differ, the PTX will likely be incorrect, though ptxas will not
necessarily catch it. And even when the sizes match, differing types are
technically undefined behavior. This change allows more bitcode to be
lowered to functioning PTX, at the cost of sometimes producing PTX that
is less obviously wrong than it would have been previously (an incorrect
indirect call is not as easy to spot as an incorrect direct call). We
consider it acceptable to generate PTX with undefined behavior, since the
behavior of calls with mismatched types is not explicitly defined.
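
Sketched from the new test's CHECK lines, the converted call for `%call2`
instead moves the callee's address into a register and types the call with
a `.callprototype` built from the call site (names such as `%rd1` and
`prototype_0` are illustrative):

```
mov.u64 %rd1, callee;               // take the address of the callee
{
  .param .b64 retval0;
  prototype_0 : .callprototype (.param .b64 _) _ ();
  call (retval0), %rd1, (), prototype_0;
  ld.param.b64 %rd2, [retval0+0];
}
```

Because the prototype is derived from the call site rather than from the
callee's declaration, the types always agree and ptxas accepts the module.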

Added: 
    llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
    llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb4fc802063a57..c5a40e4308860c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1658,6 +1658,15 @@ LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
   return RetVal;
 }
 
+static bool shouldConvertToIndirectCall(const CallBase *CB,
+                                        const GlobalAddressSDNode *Func) {
+  if (!Func)
+    return false;
+  if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
+    return CB->getFunctionType() != CalleeFunc->getFunctionType();
+  return false;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1972,10 +1981,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                     VADeclareParam->getVTList(), DeclareParamOps);
   }
 
+  // If the type of the callsite does not match that of the function, convert
+  // the callsite to an indirect call.
+  bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
+
   // Both indirect calls and libcalls have nullptr Func. In order to distinguish
   // between them we must rely on the call site value which is valid for
   // indirect calls but is always null for libcalls.
-  bool isIndirectCall = !Func && CB;
+  bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
 
   if (isa<ExternalSymbolSDNode>(Callee)) {
     Function* CalleeFunc = nullptr;
@@ -2027,6 +2040,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
   InGlue = Chain.getValue(1);
 
+  if (ConvertToIndirectCall) {
+    // Copy the function ptr to a ptx register and use the register to call the
+    // function.
+    EVT DestVT = Callee.getValueType();
+    MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    unsigned DestReg =
+        RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
+    auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
+    Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
+  }
+
   // Ops to print out the function name
   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
   SDValue CallVoidOps[] = { Chain, Callee, InGlue };

diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
index c5f7bd1bd1ba20..bd723a296e620f 100644
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -17,8 +17,8 @@ target triple = "nvptx64-nvidia-cuda"
 ; CHECK: st.param.b16   [param2+0], %rs1;
 ; CHECK: st.param.b16   [param2+2], %rs2;
 ; CHECK: .param .align 2 .b8 retval0[4];
-; CHECK: call.uni (retval0),
-; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
+; CHECK-NEXT: prototype_0 : .callprototype (.param .align 2 .b8 _[4]) _ (.param .b32 _, .param .b32 _, .param .align 2 .b8 _[4]);
+; CHECK-NEXT: call (retval0),
 define weak_odr void @foo() {
 entry:
   %call.i.i.i = tail call %"class.complex" @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32 0, i32 0, ptr byval(%"class.complex") null)

diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
new file mode 100644
index 00000000000000..2602c3b0d041b5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -0,0 +1,89 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+%struct.64 = type <{ i64 }>
+declare i64 @callee(ptr %p);
+declare i64 @callee_variadic(ptr %p, ...);
+
+define %struct.64 @test_return_type_mismatch(ptr %p) {
+; CHECK-LABEL: test_return_type_mismatch(
+; CHECK:         .param .align 1 .b8 retval0[8];
+; CHECK-NEXT:    prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_0;
+  %ret = call %struct.64 @callee(ptr %p)
+  ret %struct.64 %ret
+}
+
+define i64 @test_param_type_mismatch(ptr %p) {
+; CHECK-LABEL: test_param_type_mismatch(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    prototype_1 : .callprototype (.param .b64 _) _ (.param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_1;
+  %ret = call i64 @callee(i64 7)
+  ret i64 %ret
+}
+
+define i64 @test_param_count_mismatch(ptr %p) {
+; CHECK-LABEL: test_param_count_mismatch(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0,
+; CHECK-NEXT:    param1
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_2;
+  %ret = call i64 @callee(ptr %p, i64 7)
+  ret i64 %ret
+}
+
+define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_return_type_mismatch_variadic(
+; CHECK:         .param .align 1 .b8 retval0[8];
+; CHECK-NEXT:    prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
+; CHECK-NEXT:    call (retval0),
+; CHECK-NEXT:    %rd
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    )
+; CHECK-NEXT:    , prototype_3;
+  %ret = call %struct.64 (ptr, ...) @callee_variadic(ptr %p)
+  ret %struct.64 %ret
+}
+
+define i64 @test_param_type_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_param_type_mismatch_variadic(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee_variadic
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0,
+; CHECK-NEXT:    param1
+; CHECK-NEXT:    )
+  %ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
+  ret i64 %ret
+}
+
+define i64 @test_param_count_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_param_count_mismatch_variadic(
+; CHECK:         .param .b64 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee_variadic
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0,
+; CHECK-NEXT:    param1
+; CHECK-NEXT:    )
+  %ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
+  ret i64 %ret
+}

diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 176dfee11cfb09..b203a78d677308 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -72,21 +72,24 @@ define void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; PTX-LABEL: grid_const_escape(
 ; PTX:       {
 ; PTX-NEXT:    .reg .b32 %r<3>;
-; PTX-NEXT:    .reg .b64 %rd<4>;
+; PTX-NEXT:    .reg .b64 %rd<5>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd1, grid_const_escape_param_0;
-; PTX-NEXT:    mov.u64 %rd2, %rd1;
-; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
+; PTX-NEXT:    mov.b64 %rd2, grid_const_escape_param_0;
+; PTX-NEXT:    mov.u64 %rd3, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd4, %rd3;
+; PTX-NEXT:    mov.u64 %rd1, escape;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0+0], %rd3;
+; PTX-NEXT:    st.param.b64 [param0+0], %rd4;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    call.uni (retval0),
-; PTX-NEXT:    escape,
+; PTX-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
+; PTX-NEXT:    call (retval0),
+; PTX-NEXT:    %rd1,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0
-; PTX-NEXT:    );
+; PTX-NEXT:    )
+; PTX-NEXT:    , prototype_0;
 ; PTX-NEXT:    ld.param.b32 %r1, [retval0+0];
 ; PTX-NEXT:    } // callseq 0
 ; PTX-NEXT:    ret;
@@ -107,36 +110,39 @@ define void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<4>;
-; PTX-NEXT:    .reg .b64 %rd<9>;
+; PTX-NEXT:    .reg .b64 %rd<10>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.u64 %SPL, __local_depot3;
 ; PTX-NEXT:    cvta.local.u64 %SP, %SPL;
-; PTX-NEXT:    mov.b64 %rd1, multiple_grid_const_escape_param_0;
-; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_2;
-; PTX-NEXT:    mov.u64 %rd3, %rd2;
+; PTX-NEXT:    mov.b64 %rd2, multiple_grid_const_escape_param_0;
+; PTX-NEXT:    mov.b64 %rd3, multiple_grid_const_escape_param_2;
+; PTX-NEXT:    mov.u64 %rd4, %rd3;
 ; PTX-NEXT:    ld.param.u32 %r1, [multiple_grid_const_escape_param_1];
-; PTX-NEXT:    cvta.param.u64 %rd4, %rd3;
-; PTX-NEXT:    mov.u64 %rd5, %rd1;
-; PTX-NEXT:    cvta.param.u64 %rd6, %rd5;
-; PTX-NEXT:    add.u64 %rd7, %SP, 0;
-; PTX-NEXT:    add.u64 %rd8, %SPL, 0;
-; PTX-NEXT:    st.local.u32 [%rd8], %r1;
+; PTX-NEXT:    cvta.param.u64 %rd5, %rd4;
+; PTX-NEXT:    mov.u64 %rd6, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd7, %rd6;
+; PTX-NEXT:    add.u64 %rd8, %SP, 0;
+; PTX-NEXT:    add.u64 %rd9, %SPL, 0;
+; PTX-NEXT:    st.local.u32 [%rd9], %r1;
+; PTX-NEXT:    mov.u64 %rd1, escape3;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0+0], %rd6;
+; PTX-NEXT:    st.param.b64 [param0+0], %rd7;
 ; PTX-NEXT:    .param .b64 param1;
-; PTX-NEXT:    st.param.b64 [param1+0], %rd7;
+; PTX-NEXT:    st.param.b64 [param1+0], %rd8;
 ; PTX-NEXT:    .param .b64 param2;
-; PTX-NEXT:    st.param.b64 [param2+0], %rd4;
+; PTX-NEXT:    st.param.b64 [param2+0], %rd5;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    call.uni (retval0),
-; PTX-NEXT:    escape3,
+; PTX-NEXT:    prototype_1 : .callprototype (.param .b32 _) _ (.param .b64 _, .param .b64 _, .param .b64 _);
+; PTX-NEXT:    call (retval0),
+; PTX-NEXT:    %rd1,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0,
 ; PTX-NEXT:    param1,
 ; PTX-NEXT:    param2
-; PTX-NEXT:    );
+; PTX-NEXT:    )
+; PTX-NEXT:    , prototype_1;
 ; PTX-NEXT:    ld.param.b32 %r2, [retval0+0];
 ; PTX-NEXT:    } // callseq 1
 ; PTX-NEXT:    ret;
@@ -221,26 +227,29 @@ define void @grid_const_partial_escape(ptr byval(i32) %input, ptr %output) {
 ; PTX-LABEL: grid_const_partial_escape(
 ; PTX:       {
 ; PTX-NEXT:    .reg .b32 %r<5>;
-; PTX-NEXT:    .reg .b64 %rd<6>;
+; PTX-NEXT:    .reg .b64 %rd<7>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escape_param_0;
-; PTX-NEXT:    ld.param.u64 %rd2, [grid_const_partial_escape_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
-; PTX-NEXT:    mov.u64 %rd4, %rd1;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd4;
-; PTX-NEXT:    ld.u32 %r1, [%rd5];
+; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escape_param_0;
+; PTX-NEXT:    ld.param.u64 %rd3, [grid_const_partial_escape_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
+; PTX-NEXT:    mov.u64 %rd5, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd6, %rd5;
+; PTX-NEXT:    ld.u32 %r1, [%rd6];
 ; PTX-NEXT:    add.s32 %r2, %r1, %r1;
-; PTX-NEXT:    st.global.u32 [%rd3], %r2;
+; PTX-NEXT:    st.global.u32 [%rd4], %r2;
+; PTX-NEXT:    mov.u64 %rd1, escape;
 ; PTX-NEXT:    { // callseq 2, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0+0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0+0], %rd6;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    call.uni (retval0),
-; PTX-NEXT:    escape,
+; PTX-NEXT:    prototype_2 : .callprototype (.param .b32 _) _ (.param .b64 _);
+; PTX-NEXT:    call (retval0),
+; PTX-NEXT:    %rd1,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0
-; PTX-NEXT:    );
+; PTX-NEXT:    )
+; PTX-NEXT:    , prototype_2;
 ; PTX-NEXT:    ld.param.b32 %r3, [retval0+0];
 ; PTX-NEXT:    } // callseq 2
 ; PTX-NEXT:    ret;
@@ -266,27 +275,30 @@ define i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input, ptr %outpu
 ; PTX-LABEL: grid_const_partial_escapemem(
 ; PTX:       {
 ; PTX-NEXT:    .reg .b32 %r<6>;
-; PTX-NEXT:    .reg .b64 %rd<6>;
+; PTX-NEXT:    .reg .b64 %rd<7>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
-; PTX-NEXT:    mov.b64 %rd1, grid_const_partial_escapemem_param_0;
-; PTX-NEXT:    ld.param.u64 %rd2, [grid_const_partial_escapemem_param_1];
-; PTX-NEXT:    cvta.to.global.u64 %rd3, %rd2;
-; PTX-NEXT:    mov.u64 %rd4, %rd1;
-; PTX-NEXT:    cvta.param.u64 %rd5, %rd4;
-; PTX-NEXT:    ld.u32 %r1, [%rd5];
-; PTX-NEXT:    ld.u32 %r2, [%rd5+4];
-; PTX-NEXT:    st.global.u64 [%rd3], %rd5;
+; PTX-NEXT:    mov.b64 %rd2, grid_const_partial_escapemem_param_0;
+; PTX-NEXT:    ld.param.u64 %rd3, [grid_const_partial_escapemem_param_1];
+; PTX-NEXT:    cvta.to.global.u64 %rd4, %rd3;
+; PTX-NEXT:    mov.u64 %rd5, %rd2;
+; PTX-NEXT:    cvta.param.u64 %rd6, %rd5;
+; PTX-NEXT:    ld.u32 %r1, [%rd6];
+; PTX-NEXT:    ld.u32 %r2, [%rd6+4];
+; PTX-NEXT:    st.global.u64 [%rd4], %rd6;
 ; PTX-NEXT:    add.s32 %r3, %r1, %r2;
+; PTX-NEXT:    mov.u64 %rd1, escape;
 ; PTX-NEXT:    { // callseq 3, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0+0], %rd5;
+; PTX-NEXT:    st.param.b64 [param0+0], %rd6;
 ; PTX-NEXT:    .param .b32 retval0;
-; PTX-NEXT:    call.uni (retval0),
-; PTX-NEXT:    escape,
+; PTX-NEXT:    prototype_3 : .callprototype (.param .b32 _) _ (.param .b64 _);
+; PTX-NEXT:    call (retval0),
+; PTX-NEXT:    %rd1,
 ; PTX-NEXT:    (
 ; PTX-NEXT:    param0
-; PTX-NEXT:    );
+; PTX-NEXT:    )
+; PTX-NEXT:    , prototype_3;
 ; PTX-NEXT:    ld.param.b32 %r4, [retval0+0];
 ; PTX-NEXT:    } // callseq 3
 ; PTX-NEXT:    st.param.b32 [func_retval0+0], %r3;

More information about the llvm-commits mailing list