[llvm] 8913b35 - [NVPTX] Enforce minimum alignment of 4 for byval parameters in a function prototype

Andrew Savonichev via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 10 04:23:13 PST 2023


Author: Pavel Kopyl
Date: 2023-01-10T15:22:40+03:00
New Revision: 8913b35f082d4514318ffb9f8445bbb7ab726508

URL: https://github.com/llvm/llvm-project/commit/8913b35f082d4514318ffb9f8445bbb7ab726508
DIFF: https://github.com/llvm/llvm-project/commit/8913b35f082d4514318ffb9f8445bbb7ab726508.diff

LOG: [NVPTX] Enforce minimum alignment of 4 for byval parameters in a function prototype

As a result, we have identical alignment calculation of byval
parameters for:

  - LowerCall() - getting alignment of an argument (.param)

  - emitFunctionParamList() - getting alignment of a
    parameter (.param) in a function declaration

  - getPrototype() - getting alignment of a parameter (.param) in a
    function prototype that is used for indirect calls

This change is required to avoid ptxas error: 'Alignment of argument
does not match formal parameter'. This error happens even in cases
where it logically shouldn't.

For instance:

  .param .align 4 .b8 param0[4];
  ...
  callprototype ()_ (.param .align 2 .b8 _[4]);
  ...

Here we allocate 'param0' with alignment of 4 and it should be fine to
pass it to a function that requires minimum alignment of 2.

However, ptxas (at least v12.0) rejects this code.

Differential Revision: https://reviews.llvm.org/D140581

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
    llvm/test/CodeGen/NVPTX/param-align.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index eab19547101a6..dbf4bf4815eaf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1612,21 +1612,12 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
-      Align OptimalAlign = getOptimalAlignForParam(ETy);
-
-      // Work around a bug in ptxas. When PTX code takes address of
-      // byval parameter with alignment < 4, ptxas generates code to
-      // spill argument into memory. Alas on sm_50+ ptxas generates
-      // SASS code that fails with misaligned access. To work around
-      // the problem, make sure that we align byval parameters by at
-      // least 4. Matching change must be made in LowerCall() where we
-      // prepare parameters for the call.
-      //
-      // TODO: this will need to be undone when we get to support multi-TU
-      // device-side compilation as it breaks ABI compatibility with nvcc.
-      // Hopefully ptxas bug is fixed by then.
-      if (!isKernelFunc && OptimalAlign < Align(4))
-        OptimalAlign = Align(4);
+      Align OptimalAlign =
+          isKernelFunc
+              ? getOptimalAlignForParam(ETy)
+              : TLI->getFunctionByValParamAlign(
+                    F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
+
       unsigned sz = DL.getTypeAllocSize(ETy);
       O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
       printParamName(I, paramIndex, O);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b7e81ea596b44..62066706df0d2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1414,13 +1414,10 @@ std::string NVPTXTargetLowering::getPrototype(
       continue;
     }
 
-    Align ParamByValAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
-    // Try to increase alignment. This code matches logic in LowerCall when
-    // alignment increase is performed to increase vectorization options.
     Type *ETy = Args[i].IndirectType;
-    Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
-    ParamByValAlign = std::max(ParamByValAlign, AlignCandidate);
+    Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+    Align ParamByValAlign =
+        getFunctionByValParamAlign(F, ETy, InitialAlign, DL);
 
     O << ".param .align " << ParamByValAlign.value() << " .b8 ";
     O << "_";
@@ -1560,17 +1557,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
       // so we don't need to worry whether it's naturally aligned or not.
       // See TargetLowering::LowerCallTo().
-      ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
-      // Try to increase alignment to enhance vectorization options.
-      if (const Function *DirectCallee = CB->getCalledFunction())
-        ArgAlign = std::max(
-            ArgAlign, getFunctionParamOptimizedAlign(DirectCallee, ETy, DL));
-
-      // Enforce minumum alignment of 4 to work around ptxas miscompile
-      // for sm_50+. See corresponding alignment adjustment in
-      // emitFunctionParamList() for details.
-      ArgAlign = std::max(ArgAlign, Align(4));
+      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+      ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+                                            InitialAlign, DL);
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
@@ -4510,6 +4499,29 @@ Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
   return Align(std::max(uint64_t(16), ABITypeAlign));
 }
 
+/// Helper for computing alignment of a device function byval parameter.
+Align NVPTXTargetLowering::getFunctionByValParamAlign(
+    const Function *F, Type *ArgTy, Align InitialAlign,
+    const DataLayout &DL) const {
+  Align ArgAlign = InitialAlign;
+  // Try to increase alignment to enhance vectorization options.
+  if (F)
+    ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
+
+  // Work around a bug in ptxas. When PTX code takes address of
+  // byval parameter with alignment < 4, ptxas generates code to
+  // spill argument into memory. Alas on sm_50+ ptxas generates
+  // SASS code that fails with misaligned access. To work around
+  // the problem, make sure that we align byval parameters by at
+  // least 4.
+  // TODO: this will need to be undone when we get to support multi-TU
+  // device-side compilation as it breaks ABI compatibility with nvcc.
+  // Hopefully ptxas bug is fixed by then.
+  ArgAlign = std::max(ArgAlign, Align(4));
+
+  return ArgAlign;
+}
+
 /// isLegalAddressingMode - Return true if the addressing mode represented
 /// by AM is legal for this target, for a load/store of the specified type.
 /// Used to guide target specific optimizations, like loop strength reduction

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 78d82315fbf69..f48ec1740b0f8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -461,6 +461,11 @@ class NVPTXTargetLowering : public TargetLowering {
   Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
                                        const DataLayout &DL) const;
 
+  /// Helper for computing alignment of a device function byval parameter.
+  Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
+                                   Align InitialAlign,
+                                   const DataLayout &DL) const;
+
   /// isLegalAddressingMode - Return true if the addressing mode represented
   /// by AM is legal for this target, for a load/store of the specified type
   /// Used to guide target specific optimizations, like loop strength

diff  --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
index 7a0d78c5376e9..a743b146bf057 100644
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -37,9 +37,9 @@ define void @boom() {
   %fp = call ptr @usefp(ptr @callee)
   ; CHECK: .param .align 4 .b8 param0[4];
   ; CHECK: st.param.v2.b16 [param0+0]
-  ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
+  ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]);
   call void %fp(ptr byval(%"class.complex") null)
   ret void
 }
 
-declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()
+declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32, i32, ptr byval(%"class.complex"))

diff  --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll
index 022a7502f0f7d..f3d3003189142 100644
--- a/llvm/test/CodeGen/NVPTX/param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/param-align.ll
@@ -43,3 +43,24 @@ define ptx_device void @t5(ptr align 2 byval(i8) %x) {
   call void @t4(ptr byval(i8) %x)
   ret void
 }
+
+;;; Make sure we adjust alignment for a function prototype
+;;; in case of an inderect call.
+
+declare ptr @getfp(i32 %n)
+%struct.half2 = type { half, half }
+define ptx_device void @t6() {
+; CHECK: .func t6
+  %fp = call ptr @getfp(i32 0)
+; CHECK: prototype_2 : .callprototype ()_ (.param .align 8 .b8 _[8]);
+  call void %fp(ptr byval(double) null);
+
+  %fp2 = call ptr @getfp(i32 1)
+; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
+  call void %fp(ptr byval(%struct.half2) null);
+
+  %fp3 = call ptr @getfp(i32 2)
+; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
+  call void %fp(ptr byval(i8) null);
+  ret void
+}


        


More information about the llvm-commits mailing list