[llvm] 8913b35 - [NVPTX] Enforce minimum alignment of 4 for byval parameters in a function prototype
Andrew Savonichev via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 10 04:23:13 PST 2023
Author: Pavel Kopyl
Date: 2023-01-10T15:22:40+03:00
New Revision: 8913b35f082d4514318ffb9f8445bbb7ab726508
URL: https://github.com/llvm/llvm-project/commit/8913b35f082d4514318ffb9f8445bbb7ab726508
DIFF: https://github.com/llvm/llvm-project/commit/8913b35f082d4514318ffb9f8445bbb7ab726508.diff
LOG: [NVPTX] Enforce minimum alignment of 4 for byval parameters in a function prototype
As a result, we have identical alignment calculation of byval
parameters for:
- LowerCall() - getting alignment of an argument (.param)
- emitFunctionParamList() - getting alignment of a
parameter (.param) in a function declaration
- getPrototype() - getting alignment of a parameter (.param) in
function prototypes that are used for indirect calls
This change is required to avoid ptxas error: 'Alignment of argument
does not match formal parameter'. This error happens even in cases
where it logically shouldn't.
For instance:
.param .align 4 .b8 param0[4];
...
.callprototype ()_ (.param .align 2 .b8 _[4]);
...
Here we allocate 'param0' with an alignment of 4, so it should be fine to
pass it to a function that requires a minimum alignment of 2.
Nevertheless, ptxas (at least v12.0) rejects this code.
Differential Revision: https://reviews.llvm.org/D140581
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
llvm/test/CodeGen/NVPTX/param-align.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index eab19547101a6..dbf4bf4815eaf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1612,21 +1612,12 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
// size = typeallocsize of element type
- Align OptimalAlign = getOptimalAlignForParam(ETy);
-
- // Work around a bug in ptxas. When PTX code takes address of
- // byval parameter with alignment < 4, ptxas generates code to
- // spill argument into memory. Alas on sm_50+ ptxas generates
- // SASS code that fails with misaligned access. To work around
- // the problem, make sure that we align byval parameters by at
- // least 4. Matching change must be made in LowerCall() where we
- // prepare parameters for the call.
- //
- // TODO: this will need to be undone when we get to support multi-TU
- // device-side compilation as it breaks ABI compatibility with nvcc.
- // Hopefully ptxas bug is fixed by then.
- if (!isKernelFunc && OptimalAlign < Align(4))
- OptimalAlign = Align(4);
+ Align OptimalAlign =
+ isKernelFunc
+ ? getOptimalAlignForParam(ETy)
+ : TLI->getFunctionByValParamAlign(
+ F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
+
unsigned sz = DL.getTypeAllocSize(ETy);
O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
printParamName(I, paramIndex, O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b7e81ea596b44..62066706df0d2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1414,13 +1414,10 @@ std::string NVPTXTargetLowering::getPrototype(
continue;
}
- Align ParamByValAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
- // Try to increase alignment. This code matches logic in LowerCall when
- // alignment increase is performed to increase vectorization options.
Type *ETy = Args[i].IndirectType;
- Align AlignCandidate = getFunctionParamOptimizedAlign(F, ETy, DL);
- ParamByValAlign = std::max(ParamByValAlign, AlignCandidate);
+ Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+ Align ParamByValAlign =
+ getFunctionByValParamAlign(F, ETy, InitialAlign, DL);
O << ".param .align " << ParamByValAlign.value() << " .b8 ";
O << "_";
@@ -1560,17 +1557,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// The ByValAlign in the Outs[OIdx].Flags is always set at this point,
// so we don't need to worry whether it's naturally aligned or not.
// See TargetLowering::LowerCallTo().
- ArgAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-
- // Try to increase alignment to enhance vectorization options.
- if (const Function *DirectCallee = CB->getCalledFunction())
- ArgAlign = std::max(
- ArgAlign, getFunctionParamOptimizedAlign(DirectCallee, ETy, DL));
-
- // Enforce minumum alignment of 4 to work around ptxas miscompile
- // for sm_50+. See corresponding alignment adjustment in
- // emitFunctionParamList() for details.
- ArgAlign = std::max(ArgAlign, Align(4));
+ Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+ ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+ InitialAlign, DL);
if (IsVAArg)
VAOffset = alignTo(VAOffset, ArgAlign);
} else {
@@ -4510,6 +4499,29 @@ Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
return Align(std::max(uint64_t(16), ABITypeAlign));
}
+/// Helper for computing alignment of a device function byval parameter.
+Align NVPTXTargetLowering::getFunctionByValParamAlign(
+ const Function *F, Type *ArgTy, Align InitialAlign,
+ const DataLayout &DL) const {
+ Align ArgAlign = InitialAlign;
+ // Try to increase alignment to enhance vectorization options.
+ if (F)
+ ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
+
+ // Work around a bug in ptxas. When PTX code takes address of
+ // byval parameter with alignment < 4, ptxas generates code to
+ // spill argument into memory. Alas on sm_50+ ptxas generates
+ // SASS code that fails with misaligned access. To work around
+ // the problem, make sure that we align byval parameters by at
+ // least 4.
+ // TODO: this will need to be undone when we get to support multi-TU
+ // device-side compilation as it breaks ABI compatibility with nvcc.
+ // Hopefully ptxas bug is fixed by then.
+ ArgAlign = std::max(ArgAlign, Align(4));
+
+ return ArgAlign;
+}
+
/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target specific optimizations, like loop strength reduction
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 78d82315fbf69..f48ec1740b0f8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -461,6 +461,11 @@ class NVPTXTargetLowering : public TargetLowering {
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
const DataLayout &DL) const;
+ /// Helper for computing alignment of a device function byval parameter.
+ Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
+ Align InitialAlign,
+ const DataLayout &DL) const;
+
/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type
/// Used to guide target specific optimizations, like loop strength
diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
index 7a0d78c5376e9..a743b146bf057 100644
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -37,9 +37,9 @@ define void @boom() {
%fp = call ptr @usefp(ptr @callee)
; CHECK: .param .align 4 .b8 param0[4];
; CHECK: st.param.v2.b16 [param0+0]
- ; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
+ ; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]);
call void %fp(ptr byval(%"class.complex") null)
ret void
}
-declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE()
+declare %complex_half @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32, i32, ptr byval(%"class.complex"))
diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll
index 022a7502f0f7d..f3d3003189142 100644
--- a/llvm/test/CodeGen/NVPTX/param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/param-align.ll
@@ -43,3 +43,24 @@ define ptx_device void @t5(ptr align 2 byval(i8) %x) {
call void @t4(ptr byval(i8) %x)
ret void
}
+
+;;; Make sure we adjust alignment for a function prototype
+;;; in case of an indirect call.
+
+declare ptr @getfp(i32 %n)
+%struct.half2 = type { half, half }
+define ptx_device void @t6() {
+; CHECK: .func t6
+ %fp = call ptr @getfp(i32 0)
+; CHECK: prototype_2 : .callprototype ()_ (.param .align 8 .b8 _[8]);
+ call void %fp(ptr byval(double) null);
+
+ %fp2 = call ptr @getfp(i32 1)
+; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
+ call void %fp(ptr byval(%struct.half2) null);
+
+ %fp3 = call ptr @getfp(i32 2)
+; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
+ call void %fp(ptr byval(i8) null);
+ ret void
+}
More information about the llvm-commits
mailing list