[llvm] 940fa35 - [NVPTX] Fix a segfault for bitcasted calls with byval params

Luke Drummond via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 11 07:14:00 PDT 2022


Author: Luke Drummond
Date: 2022-10-11T15:12:25+01:00
New Revision: 940fa35ece5294a115a2fdba89ef6c095d90df0f

URL: https://github.com/llvm/llvm-project/commit/940fa35ece5294a115a2fdba89ef6c095d90df0f
DIFF: https://github.com/llvm/llvm-project/commit/940fa35ece5294a115a2fdba89ef6c095d90df0f.diff

LOG: [NVPTX] Fix a segfault for bitcasted calls with byval params

`getFunctionParamOptimizedAlign` was being passed a null function
argument when getting the callee of a bitcasted function symbol. This is
because `CallBase::getCalledFunction` does not look through bitcasts.

There is already code to handle this case in
`NVPTXTargetLowering::getArgumentAlignment`, which is now hoisted into
an NVPTX util.

The alignment computation now gracefully handles computing alignment of
virtual functions with a check for null.

Added: 
    llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
    llvm/lib/Target/NVPTX/NVPTXUtilities.h

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index dbc57eec3af29..b88a20a92387b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1429,22 +1429,8 @@ Align NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
       // Check if we have call alignment metadata
       if (getAlign(*CI, Idx, Alignment))
         return Align(Alignment);
-
-      const Value *CalleeV = CI->getCalledOperand();
-      // Ignore any bitcast instructions
-      while (isa<ConstantExpr>(CalleeV)) {
-        const ConstantExpr *CE = cast<ConstantExpr>(CalleeV);
-        if (!CE->isCast())
-          break;
-        // Look through the bitcast
-        CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0);
-      }
-
-      // We have now looked past all of the bitcasts.  Do we finally have a
-      // Function?
-      if (const auto *CalleeF = dyn_cast<Function>(CalleeV))
-        DirectCallee = CalleeF;
     }
+    DirectCallee = getMaybeBitcastedCallee(CB);
   }
 
   // Check for function alignment information if we found that the
@@ -1521,7 +1507,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
       // Try to increase alignment to enhance vectorization options.
       ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(
-                                        CB->getCalledFunction(), ETy, DL));
+                                        getMaybeBitcastedCallee(CB), ETy, DL));
 
       // Enforce minumum alignment of 4 to work around ptxas miscompile
       // for sm_50+. See corresponding alignment adjustment in
@@ -4341,7 +4327,7 @@ Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
 
   // If a function has linkage different from internal or private, we
   // must use default ABI alignment as external users rely on it.
-  if (!F->hasLocalLinkage())
+  if (!(F && F->hasLocalLinkage()))
     return Align(ABITypeAlign);
 
   assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");

diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 4e41515b997db..90d4d35cded04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -324,4 +324,8 @@ bool getAlign(const CallInst &I, unsigned index, unsigned &align) {
   return false;
 }
 
+Function *getMaybeBitcastedCallee(const CallBase *CB) {
+  return dyn_cast<Function>(CB->getCalledOperand()->stripPointerCasts());
+}
+
 } // namespace llvm

diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 6fee57b4664ed..6e6b355f74bee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -58,6 +58,7 @@ bool isKernelFunction(const Function &);
 
 bool getAlign(const Function &, unsigned index, unsigned &);
 bool getAlign(const CallInst &, unsigned index, unsigned &);
+Function *getMaybeBitcastedCallee(const CallBase *CB);
 
 // PTX ABI requires all scalar argument/return values to have
 // bit-size as a power of two of at least 32 bits.

diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
new file mode 100644
index 0000000000000..7631824339258
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -0,0 +1,28 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_50 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_50 -verify-machineinstrs | %ptxas-verify %}
+
+; calls with a bitcasted function symbol should be fine, but in combination with
+; a byval attribute were causing a segfault during isel. This testcase was
+; reduced from a SYCL kernel using aggregate types which ended up being passed
+; `byval`
+
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64-nvidia-cuda"
+
+%"class.complex" = type { %"class.sycl::_V1::detail::half_impl::half", %"class.sycl::_V1::detail::half_impl::half" }
+%"class.sycl::_V1::detail::half_impl::half" = type { half }
+%complex_half = type { half, half }
+
+define weak_odr void @foo() {
+entry:
+  %call.i.i.i = tail call %"class.complex" bitcast (%complex_half ()* @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE to %"class.complex" (i32, i32, %"class.complex"*)*)(i32 0, i32 0, %"class.complex"* byval(%"class.complex") null)
+  ret void
+}
+
+declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()
+
+; CHECK: .param .align 4 .b8 param2[4];
+; CHECK: st.param.v2.b16         [param2+0], {%h2, %h1};
+; CHECK: .param .align 2 .b8 retval0[4];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,


        


More information about the llvm-commits mailing list