[llvm] [NVPTX] Cleanup and refactor param align computation, addressing a few minor bugs and discrepancies (PR #188588)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 2 11:47:15 PDT 2026
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/188588
>From f05c77c0dddd824712d22cec14f32647c129f511 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 23 Mar 2026 20:07:28 +0000
Subject: [PATCH 1/5] fix align bugs
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 3 +-
llvm/test/CodeGen/NVPTX/param-overalign.ll | 7 ++--
llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll | 38 +++++++++++++++++++
3 files changed, 43 insertions(+), 5 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index af5991ebf2c8f..f5542ad7cbfa4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4190,7 +4190,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
LLVMContext &Ctx = *DAG.getContext();
const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
- const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
+ const auto RetAlign =
+ getFunctionArgumentAlignment(&F, RetTy, AttributeList::ReturnIndex, DL);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
diff --git a/llvm/test/CodeGen/NVPTX/param-overalign.ll b/llvm/test/CodeGen/NVPTX/param-overalign.ll
index 2ee749fb3b0cb..83add8a89b07c 100644
--- a/llvm/test/CodeGen/NVPTX/param-overalign.ll
+++ b/llvm/test/CodeGen/NVPTX/param-overalign.ll
@@ -106,10 +106,9 @@ define alignstack(8) %struct.float2 @aligned_return(%struct.float2 %a ) {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b32 %r1, [aligned_return_param_0];
-; CHECK-NEXT: ld.param.b32 %r2, [aligned_return_param_0+4];
-; CHECK-NEXT: st.param.b32 [func_retval0+4], %r2;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ld.param.b32 %r1, [aligned_return_param_0+4];
+; CHECK-NEXT: ld.param.b32 %r2, [aligned_return_param_0];
+; CHECK-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
; CHECK-NEXT: ret;
ret %struct.float2 %a
}
diff --git a/llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll b/llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll
new file mode 100644
index 0000000000000..1f5159a3f4979
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll
@@ -0,0 +1,38 @@
+; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s
+
+; Verify that return value alignment is consistent between the callee
+; (LowerReturn), the declaration (printReturnValStr), and the caller
+; (LowerCall). All three should honor alignstack on the return index.
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.big = type { i32, i32, i32, i32, i32 }
+
+; alignstack(4) on the return forces align 4 everywhere: the declaration,
+; the callee stores, and the caller loads all use scalar b32 ops.
+; CHECK-LABEL: .func (.param .align 4 .b8 func_retval0[20]) internal_ret()
+; CHECK-NOT: st.param.v4
+; CHECK: st.param.b32 [func_retval0+16], 5
+; CHECK: st.param.b32 [func_retval0+12], 4
+; CHECK: st.param.b32 [func_retval0+8], 3
+; CHECK: st.param.b32 [func_retval0+4], 2
+; CHECK: st.param.b32 [func_retval0], 1
+
+define internal alignstack(4) %struct.big @internal_ret() {
+ ret %struct.big { i32 1, i32 2, i32 3, i32 4, i32 5 }
+}
+
+; The caller also reads the return value with align 4 (scalar loads).
+; CHECK-LABEL: .visible .func (.param .align 4 .b8 func_retval0[20]) caller()
+; CHECK: .param .align 4 .b8 retval0[20];
+; CHECK: call.uni (retval0), internal_ret
+; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+16];
+; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+12];
+; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+8];
+; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0+4];
+; CHECK: ld.param.b32 {{%r[0-9]+}}, [retval0];
+
+define %struct.big @caller() {
+ %r = call %struct.big @internal_ret()
+ ret %struct.big %r
+}
>From a60a07241f82cf31273ace4c98a3152d489f7c1d Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 24 Mar 2026 20:57:30 +0000
Subject: [PATCH 2/5] nfc reorg
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 15 ++-------
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 33 --------------------
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 34 +++++++++++++++++++++
llvm/lib/Target/NVPTX/NVPTXUtilities.h | 6 ++++
4 files changed, 42 insertions(+), 46 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9e7a58498040e..7763596a93afb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1382,17 +1382,6 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
}
}
- auto GetOptimalAlignForParam = [&DL, F, &Arg](Type *Ty) -> Align {
- if (MaybeAlign StackAlign =
- getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
- return StackAlign.value();
-
- Align TypeAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
- MaybeAlign ParamAlign =
- Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
- return std::max(TypeAlign, ParamAlign.valueOrOne());
- };
-
if (Arg.hasByValAttr()) {
// param has byVal attribute.
Type *ETy = Arg.getParamByValType();
@@ -1403,7 +1392,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
// PAL.getParamAlignment
// size = typeallocsize of element type
const Align OptimalAlign =
- IsKernelFunc ? GetOptimalAlignForParam(ETy)
+ IsKernelFunc ? getOptimalAlignForParam(F, Arg, ETy, DL)
: getFunctionByValParamAlign(
F, ETy, Arg.getParamAlign().valueOrOne(), DL);
@@ -1417,7 +1406,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
// size = typeallocsize of element type
- Align OptimalAlign = GetOptimalAlignForParam(Ty);
+ Align OptimalAlign = getOptimalAlignForParam(F, Arg, Ty, DL);
O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
<< "[" << DL.getTypeAllocSize(Ty) << "]";
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f5542ad7cbfa4..af1d8cdaa0f27 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1186,9 +1186,6 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
}
}
-static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
- const DataLayout &DL);
-
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs,
@@ -1293,36 +1290,6 @@ std::string NVPTXTargetLowering::getPrototype(
return Prototype;
}
-static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
- const DataLayout &DL) {
- if (!CB) {
- // CallSite is zero, fallback to ABI type alignment
- return DL.getABITypeAlign(Ty);
- }
-
- const Function *DirectCallee = CB->getCalledFunction();
-
- if (!DirectCallee) {
- // We don't have a direct function symbol, but that may be because of
- // constant cast instructions in the call.
-
- // With bitcast'd call targets, the instruction will be the call
- if (const auto *CI = dyn_cast<CallInst>(CB)) {
- // Check if we have call alignment metadata
- if (MaybeAlign StackAlign = getAlign(*CI, Idx))
- return StackAlign.value();
- }
- DirectCallee = getMaybeBitcastedCallee(CB);
- }
-
- // Check for function alignment information if we found that the
- // ultimate target is a Function
- if (DirectCallee)
- return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
-
- // Call is indirect, fall back to the ABI type alignment
- return DL.getABITypeAlign(Ty);
-}
static bool shouldConvertToIndirectCall(const CallBase *CB,
const GlobalAddressSDNode *Func) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 6cf808bc9c858..10630f8b6bc43 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -14,6 +14,7 @@
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVVMProperties.h"
+#include "llvm/IR/Argument.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/Alignment.h"
@@ -78,6 +79,39 @@ Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
return ArgAlign;
}
+Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
+ const DataLayout &DL) {
+ if (MaybeAlign StackAlign =
+ getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
+ return StackAlign.value();
+
+ Align TypeAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
+ MaybeAlign ParamAlign =
+ Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
+ return std::max(TypeAlign, ParamAlign.valueOrOne());
+}
+
+Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
+ const DataLayout &DL) {
+ if (!CB)
+ return DL.getABITypeAlign(Ty);
+
+ const Function *DirectCallee = CB->getCalledFunction();
+
+ if (!DirectCallee) {
+ if (const auto *CI = dyn_cast<CallInst>(CB)) {
+ if (MaybeAlign StackAlign = getAlign(*CI, Idx))
+ return StackAlign.value();
+ }
+ DirectCallee = getMaybeBitcastedCallee(CB);
+ }
+
+ if (DirectCallee)
+ return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
+
+ return DL.getABITypeAlign(Ty);
+}
+
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
const auto &ST =
*static_cast<const NVPTXTargetMachine &>(TM).getSubtargetImpl();
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index b6e18bf998897..7d07206f36c11 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -47,6 +47,12 @@ Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
Align InitialAlign, const DataLayout &DL);
+Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
+ const DataLayout &DL);
+
+Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
+ const DataLayout &DL);
+
// PTX ABI requires all scalar argument/return values to have
// bit-size as a power of two of at least 32 bits.
inline unsigned promoteScalarArgumentSize(unsigned size) {
>From 882101971964875e1253f79ac4daf92b97aaecc4 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 25 Mar 2026 20:08:25 +0000
Subject: [PATCH 3/5] refactor
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 13 +--
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 23 ++---
.../Target/NVPTX/NVPTXSetByValParamAlign.cpp | 2 +-
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 89 +++++++++----------
llvm/lib/Target/NVPTX/NVPTXUtilities.h | 34 ++++---
llvm/lib/Target/NVPTX/NVVMProperties.cpp | 2 +-
llvm/lib/Target/NVPTX/NVVMProperties.h | 4 +-
.../CodeGen/NVPTX/nvvm-annotations-D120129.ll | 2 +-
8 files changed, 86 insertions(+), 83 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 7763596a93afb..2d79bd4924e20 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -276,7 +276,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
if (shouldPassAsArray(Ty)) {
const unsigned TotalSize = DL.getTypeAllocSize(Ty);
const Align RetAlignment =
- getFunctionArgumentAlignment(F, Ty, AttributeList::ReturnIndex, DL);
+ getParamAlign(F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
<< TotalSize << "]";
} else if (Ty->isFloatingPointTy()) {
@@ -1392,9 +1392,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
// PAL.getParamAlignment
// size = typeallocsize of element type
const Align OptimalAlign =
- IsKernelFunc ? getOptimalAlignForParam(F, Arg, ETy, DL)
- : getFunctionByValParamAlign(
- F, ETy, Arg.getParamAlign().valueOrOne(), DL);
+ IsKernelFunc
+ ? getParamAlign(F, ETy,
+ Arg.getArgNo() + AttributeList::FirstArgIndex, DL)
+ : getDeviceByValParamAlign(F, ETy,
+ Arg.getParamAlign().valueOrOne(), DL);
O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
<< "[" << DL.getTypeAllocSize(ETy) << "]";
@@ -1406,7 +1408,8 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
// size = typeallocsize of element type
- Align OptimalAlign = getOptimalAlignForParam(F, Arg, Ty, DL);
+ Align OptimalAlign = getParamAlign(
+ F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
<< "[" << DL.getTypeAllocSize(Ty) << "]";
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index af1d8cdaa0f27..a8d8f8cd351cb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1202,7 +1202,8 @@ std::string NVPTXTargetLowering::getPrototype(
} else {
O << "(";
if (shouldPassAsArray(RetTy)) {
- const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
+ const Align RetAlign =
+ getParamAlign(&CB, RetTy, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlign.value() << " .b8 _["
<< DL.getTypeAllocSize(RetTy) << "]";
} else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
@@ -1250,14 +1251,14 @@ std::string NVPTXTargetLowering::getPrototype(
Type *ETy = Args[I].IndirectType;
Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
Align ParamByValAlign =
- getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
+ getDeviceByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
O << ".param .align " << ParamByValAlign.value() << " .b8 _["
<< ArgOuts[0].Flags.getByValSize() << "]";
} else {
if (shouldPassAsArray(Ty)) {
Align ParamAlign =
- getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+ getParamAlign(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
O << ".param .align " << ParamAlign.value() << " .b8 _["
<< DL.getTypeAllocSize(Ty) << "]";
continue;
@@ -1466,10 +1467,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// so we don't need to worry whether it's naturally aligned or not.
// See TargetLowering::LowerCallTo().
const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
- return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
- InitialAlign, DL);
+ return getDeviceByValParamAlign(CB->getCalledFunction(), ETy,
+ InitialAlign, DL);
}
- return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
+ return getParamAlign(CB, Arg.Ty, ArgI + AttributeList::FirstArgIndex, DL);
}();
const unsigned TySize = DL.getTypeAllocSize(ETy);
@@ -1603,7 +1604,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
if (shouldPassAsArray(RetTy)) {
- const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ const Align RetAlign =
+ getParamAlign(CB, RetTy, AttributeList::ReturnIndex, DL);
MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
} else {
MakeDeclareScalarParam(RetSymbol, ResultSize);
@@ -1682,7 +1684,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
assert(VTs.size() == Ins.size() && "Bad value decomposition");
- const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ const Align RetAlign =
+ getParamAlign(CB, RetTy, AttributeList::ReturnIndex, DL);
const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
@@ -4100,7 +4103,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(VTs.size() == ArgIns.size() && "Size mismatch");
assert(VTs.size() == Offsets.size() && "Size mismatch");
- const Align ArgAlign = getFunctionArgumentAlignment(
+ const Align ArgAlign = getParamAlign(
&F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
unsigned I = 0;
@@ -4158,7 +4161,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
const auto RetAlign =
- getFunctionArgumentAlignment(&F, RetTy, AttributeList::ReturnIndex, DL);
+ getParamAlign(&F, RetTy, AttributeList::ReturnIndex, DL);
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
// 32-bits are sign extended or zero extended, depending on whether
diff --git a/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp b/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp
index 214078c1967ab..45d4965ed2c9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp
@@ -64,7 +64,7 @@ static Align setByValParamAlign(Argument *Arg) {
Type *ByValType = Arg->getParamByValType();
const DataLayout &DL = F->getDataLayout();
- const Align OptimizedAlign = getFunctionParamOptimizedAlign(F, ByValType, DL);
+ const Align OptimizedAlign = getPromotedParamTypeAlign(F, ByValType, DL);
const Align CurrentAlign = Arg->getParamAlign().valueOrOne();
if (CurrentAlign >= OptimizedAlign)
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 10630f8b6bc43..56e71e476b3ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -14,7 +14,7 @@
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVVMProperties.h"
-#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/Alignment.h"
@@ -33,8 +33,8 @@ Function *getMaybeBitcastedCallee(const CallBase *CB) {
return dyn_cast<Function>(CB->getCalledOperand()->stripPointerCasts());
}
-Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
- const DataLayout &DL) {
+Align getPromotedParamTypeAlign(const Function *F, Type *ArgTy,
+ const DataLayout &DL) {
// Capping the alignment to 128 bytes as that is the maximum alignment
// supported by PTX.
const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
@@ -42,27 +42,21 @@ Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
// If a function has linkage different from internal or private, we
// must use default ABI alignment as external users rely on it. Same
// for a function that may be called from a function pointer.
- if (!F || !F->hasLocalLinkage() ||
- F->hasAddressTaken(/*Users=*/nullptr,
- /*IgnoreCallbackUses=*/false,
- /*IgnoreAssumeLikeCalls=*/true,
- /*IgnoreLLVMUsed=*/true))
- return ABITypeAlign;
-
- assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
- return std::max(Align(16), ABITypeAlign);
+ const bool MayOptimizeAlign =
+ F && F->hasLocalLinkage() &&
+ !F->hasAddressTaken(/*Users=*/nullptr,
+ /*IgnoreCallbackUses=*/false,
+ /*IgnoreAssumeLikeCalls=*/true,
+ /*IgnoreLLVMUsed=*/true);
+ assert(!(MayOptimizeAlign && isKernelFunction(*F)) &&
+ "Expect kernels to have non-local linkage");
+ const Align OptimizedAlign = MayOptimizeAlign ? Align(16) : Align(1);
+ return std::max(OptimizedAlign, ABITypeAlign);
}
-Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
- const DataLayout &DL) {
- return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
-}
-
-Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
- Align InitialAlign, const DataLayout &DL) {
- Align ArgAlign = InitialAlign;
- if (F)
- ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
+Align getDeviceByValParamAlign(const Function *F, Type *ArgTy,
+ Align InitialAlign, const DataLayout &DL) {
+ const Align OptimizedAlign = getPromotedParamTypeAlign(F, ArgTy, DL);
// Old ptx versions have a bug. When PTX code takes address of
// byval parameter with alignment < 4, ptxas generates code to
@@ -73,43 +67,40 @@ Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
// ptxas > 9.0.
// TODO: remove this after verifying the bug is not reproduced
// on non-deprecated ptxas versions.
- if (ForceMinByValParamAlign)
- ArgAlign = std::max(ArgAlign, Align(4));
+ const bool ShouldForceMinAlign =
+ ForceMinByValParamAlign && (!F || !isKernelFunction(*F));
+ const Align AlignFloor = ShouldForceMinAlign ? Align(4) : Align(1);
- return ArgAlign;
+ return std::max({InitialAlign, OptimizedAlign, AlignFloor});
}
-Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
- const DataLayout &DL) {
- if (MaybeAlign StackAlign =
- getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
- return StackAlign.value();
-
- Align TypeAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
- MaybeAlign ParamAlign =
- Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
- return std::max(TypeAlign, ParamAlign.valueOrOne());
+Align getParamAlign(const Function *F, Type *Ty, unsigned AttrIdx,
+ const DataLayout &DL) {
+ if (F)
+ if (MaybeAlign StackAlign = getStackAlign(*F, AttrIdx))
+ return StackAlign.value();
+
+ Align TypeAlign = getPromotedParamTypeAlign(F, Ty, DL);
+ if (F && AttrIdx >= AttributeList::FirstArgIndex) {
+ unsigned ArgNo = AttrIdx - AttributeList::FirstArgIndex;
+ if (F->getAttributes().hasParamAttr(ArgNo, Attribute::ByVal))
+ return std::max(TypeAlign, F->getParamAlign(ArgNo).valueOrOne());
+ }
+ return TypeAlign;
}
-Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
- const DataLayout &DL) {
- if (!CB)
- return DL.getABITypeAlign(Ty);
+Align getParamAlign(const CallBase *CB, Type *Ty, unsigned Idx,
+ const DataLayout &DL) {
+ const Function *DirectCallee = CB ? CB->getCalledFunction() : nullptr;
- const Function *DirectCallee = CB->getCalledFunction();
+ if (!DirectCallee && CB) {
+ if (MaybeAlign StackAlign = getStackAlign(*CB, Idx))
+ return StackAlign.value();
- if (!DirectCallee) {
- if (const auto *CI = dyn_cast<CallInst>(CB)) {
- if (MaybeAlign StackAlign = getAlign(*CI, Idx))
- return StackAlign.value();
- }
DirectCallee = getMaybeBitcastedCallee(CB);
}
- if (DirectCallee)
- return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
-
- return DL.getABITypeAlign(Ty);
+ return getParamAlign(DirectCallee, Ty, Idx, DL);
}
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 7d07206f36c11..7e1eca7284bb9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -38,20 +38,26 @@ Function *getMaybeBitcastedCallee(const CallBase *CB);
/// function has internal or private linkage as for other linkage types callers
/// may already rely on default alignment. To allow using 128-bit vectorized
/// loads/stores, this function ensures that alignment is 16 or greater.
-Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
- const DataLayout &DL);
-
-Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
- const DataLayout &DL);
-
-Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
- Align InitialAlign, const DataLayout &DL);
-
-Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
- const DataLayout &DL);
-
-Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
- const DataLayout &DL);
+Align getPromotedParamTypeAlign(const Function *F, Type *ArgTy,
+ const DataLayout &DL);
+
+Align getDeviceByValParamAlign(const Function *F, Type *ArgTy,
+ Align InitialAlign, const DataLayout &DL);
+
+/// Get the alignment for a function parameter or return value.
+/// \p AttrIdx is the AttributeList index (e.g. FirstArgIndex + argNo, or
+/// ReturnIndex for return values). Checks for an explicit alignstack attribute,
+/// then falls back to getPromotedParamTypeAlign, incorporating byval param
+/// alignment when applicable.
+Align getParamAlign(const Function *F, Type *Ty, unsigned AttrIdx,
+ const DataLayout &DL);
+
+/// Get the alignment for a call-site argument or return value. Resolves the
+/// callee and delegates to the Function overload of getParamAlign. For
+/// indirect calls with no resolvable callee, falls back to
+/// getPromotedParamTypeAlign.
+Align getParamAlign(const CallBase *CB, Type *Ty, unsigned AttrIdx,
+ const DataLayout &DL);
// PTX ABI requires all scalar argument/return values to have
// bit-size as a power of two of at least 32 bits.
diff --git a/llvm/lib/Target/NVPTX/NVVMProperties.cpp b/llvm/lib/Target/NVPTX/NVVMProperties.cpp
index d68c5aaf4fe5f..012d0863c183d 100644
--- a/llvm/lib/Target/NVPTX/NVVMProperties.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMProperties.cpp
@@ -317,7 +317,7 @@ bool isParamGridConstant(const Argument &Arg) {
return Arg.hasAttribute(NVVMAttr::GridConstant);
}
-MaybeAlign getAlign(const CallInst &I, unsigned Index) {
+MaybeAlign getStackAlign(const CallBase &I, unsigned Index) {
// First check the alignstack metadata.
if (MaybeAlign StackAlign =
I.getAttributes().getAttributes(Index).getStackAlignment())
diff --git a/llvm/lib/Target/NVPTX/NVVMProperties.h b/llvm/lib/Target/NVPTX/NVVMProperties.h
index 6ccd6f8a20075..601b4827bbc8d 100644
--- a/llvm/lib/Target/NVPTX/NVVMProperties.h
+++ b/llvm/lib/Target/NVPTX/NVVMProperties.h
@@ -59,10 +59,10 @@ bool hasBlocksAreClusters(const Function &);
bool isParamGridConstant(const Argument &);
-inline MaybeAlign getAlign(const Function &F, unsigned Index) {
+inline MaybeAlign getStackAlign(const Function &F, unsigned Index) {
return F.getAttributes().getAttributes(Index).getStackAlignment();
}
-MaybeAlign getAlign(const CallInst &, unsigned);
+MaybeAlign getStackAlign(const CallBase &, unsigned);
} // namespace llvm
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll b/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll
index 0b8d247b0bca6..bbc83cfc3b7c8 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll
@@ -1,7 +1,7 @@
; RUN: llc < %s -mtriple=nvptx64-unknown-unknown | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64-unknown-unknown | %ptxas-verify %}
;
-; NVPTXTargetLowering::getFunctionParamOptimizedAlign, which was introduces in
+; NVPTXTargetLowering::getPromotedParamTypeAlign, which was introduced in
; D120129, contained a poorly designed assertion checking that a function with
; internal or private linkage is not a kernel. It relied on invariants that
; were not actually guaranteed, and that resulted in compiler crash with some
>From 07f8570850e8c34d431e0400c1a3288ecf5acb7b Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 30 Mar 2026 16:53:37 +0000
Subject: [PATCH 4/5] fixup
---
llvm/lib/Target/NVPTX/NVVMProperties.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/NVPTX/NVVMProperties.h b/llvm/lib/Target/NVPTX/NVVMProperties.h
index 601b4827bbc8d..7187b18d3cbf7 100644
--- a/llvm/lib/Target/NVPTX/NVVMProperties.h
+++ b/llvm/lib/Target/NVPTX/NVVMProperties.h
@@ -24,7 +24,7 @@
namespace llvm {
class Argument;
-class CallInst;
+class CallBase;
class GlobalVariable;
class Module;
class Value;
>From b45b08b3306f393d6d912c868662aeeb1af93b7c Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 2 Apr 2026 18:46:55 +0000
Subject: [PATCH 5/5] clang-format
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a8d8f8cd351cb..f50ae5db458e7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1291,7 +1291,6 @@ std::string NVPTXTargetLowering::getPrototype(
return Prototype;
}
-
static bool shouldConvertToIndirectCall(const CallBase *CB,
const GlobalAddressSDNode *Func) {
if (!Func)
More information about the llvm-commits
mailing list