[llvm] [NVPTX] Cleanup and refactor param align computation, addressing a few minor bugs and discrepancies (PR #188588)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 2 11:47:15 PDT 2026


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/188588

>From f05c77c0dddd824712d22cec14f32647c129f511 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 23 Mar 2026 20:07:28 +0000
Subject: [PATCH 1/5] fix align bugs

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  3 +-
 llvm/test/CodeGen/NVPTX/param-overalign.ll    |  7 ++--
 llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll | 38 +++++++++++++++++++
 3 files changed, 43 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index af5991ebf2c8f..f5542ad7cbfa4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4190,7 +4190,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   LLVMContext &Ctx = *DAG.getContext();
 
   const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
-  const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
+  const auto RetAlign =
+      getFunctionArgumentAlignment(&F, RetTy, AttributeList::ReturnIndex, DL);
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
diff --git a/llvm/test/CodeGen/NVPTX/param-overalign.ll b/llvm/test/CodeGen/NVPTX/param-overalign.ll
index 2ee749fb3b0cb..83add8a89b07c 100644
--- a/llvm/test/CodeGen/NVPTX/param-overalign.ll
+++ b/llvm/test/CodeGen/NVPTX/param-overalign.ll
@@ -106,10 +106,9 @@ define alignstack(8) %struct.float2 @aligned_return(%struct.float2 %a ) {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b32 %r1, [aligned_return_param_0];
-; CHECK-NEXT:    ld.param.b32 %r2, [aligned_return_param_0+4];
-; CHECK-NEXT:    st.param.b32 [func_retval0+4], %r2;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ld.param.b32 %r1, [aligned_return_param_0+4];
+; CHECK-NEXT:    ld.param.b32 %r2, [aligned_return_param_0];
+; CHECK-NEXT:    st.param.v2.b32 [func_retval0], {%r2, %r1};
 ; CHECK-NEXT:    ret;
   ret %struct.float2 %a
 }
diff --git a/llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll b/llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll
new file mode 100644
index 0000000000000..1f5159a3f4979
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ret-align-mismatch.ll
@@ -0,0 +1,38 @@
+; RUN: llc < %s -mtriple=nvptx64 | FileCheck %s
+
+; Verify that return value alignment is consistent between the callee
+; (LowerReturn), the declaration (printReturnValStr), and the caller
+; (LowerCall). All three should honor alignstack on the return index.
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.big = type { i32, i32, i32, i32, i32 }
+
+; alignstack(4) on the return forces align 4 everywhere: the declaration,
+; the callee stores, and the caller loads all use scalar b32 ops.
+; CHECK-LABEL: .func (.param .align 4 .b8 func_retval0[20]) internal_ret()
+; CHECK-NOT:     st.param.v4
+; CHECK:         st.param.b32 [func_retval0+16], 5
+; CHECK:         st.param.b32 [func_retval0+12], 4
+; CHECK:         st.param.b32 [func_retval0+8], 3
+; CHECK:         st.param.b32 [func_retval0+4], 2
+; CHECK:         st.param.b32 [func_retval0], 1
+
+define internal alignstack(4) %struct.big @internal_ret() {
+  ret %struct.big { i32 1, i32 2, i32 3, i32 4, i32 5 }
+}
+
+; The caller also reads the return value with align 4 (scalar loads).
+; CHECK-LABEL: .visible .func (.param .align 4 .b8 func_retval0[20]) caller()
+; CHECK:         .param .align 4 .b8 retval0[20];
+; CHECK:         call.uni (retval0), internal_ret
+; CHECK:         ld.param.b32 {{%r[0-9]+}}, [retval0+16];
+; CHECK:         ld.param.b32 {{%r[0-9]+}}, [retval0+12];
+; CHECK:         ld.param.b32 {{%r[0-9]+}}, [retval0+8];
+; CHECK:         ld.param.b32 {{%r[0-9]+}}, [retval0+4];
+; CHECK:         ld.param.b32 {{%r[0-9]+}}, [retval0];
+
+define %struct.big @caller() {
+  %r = call %struct.big @internal_ret()
+  ret %struct.big %r
+}

>From a60a07241f82cf31273ace4c98a3152d489f7c1d Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Tue, 24 Mar 2026 20:57:30 +0000
Subject: [PATCH 2/5] nfc reorg

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp   | 15 ++-------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 33 --------------------
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp    | 34 +++++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXUtilities.h      |  6 ++++
 4 files changed, 42 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9e7a58498040e..7763596a93afb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1382,17 +1382,6 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       }
     }
 
-    auto GetOptimalAlignForParam = [&DL, F, &Arg](Type *Ty) -> Align {
-      if (MaybeAlign StackAlign =
-              getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
-        return StackAlign.value();
-
-      Align TypeAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
-      MaybeAlign ParamAlign =
-          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
-      return std::max(TypeAlign, ParamAlign.valueOrOne());
-    };
-
     if (Arg.hasByValAttr()) {
       // param has byVal attribute.
       Type *ETy = Arg.getParamByValType();
@@ -1403,7 +1392,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
       const Align OptimalAlign =
-          IsKernelFunc ? GetOptimalAlignForParam(ETy)
+          IsKernelFunc ? getOptimalAlignForParam(F, Arg, ETy, DL)
                        : getFunctionByValParamAlign(
                              F, ETy, Arg.getParamAlign().valueOrOne(), DL);
 
@@ -1417,7 +1406,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
-      Align OptimalAlign = GetOptimalAlignForParam(Ty);
+      Align OptimalAlign = getOptimalAlignForParam(F, Arg, Ty, DL);
 
       O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
         << "[" << DL.getTypeAllocSize(Ty) << "]";
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f5542ad7cbfa4..af1d8cdaa0f27 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1186,9 +1186,6 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
   }
 }
 
-static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
-                                  const DataLayout &DL);
-
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs,
@@ -1293,36 +1290,6 @@ std::string NVPTXTargetLowering::getPrototype(
   return Prototype;
 }
 
-static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
-                                  const DataLayout &DL) {
-  if (!CB) {
-    // CallSite is zero, fallback to ABI type alignment
-    return DL.getABITypeAlign(Ty);
-  }
-
-  const Function *DirectCallee = CB->getCalledFunction();
-
-  if (!DirectCallee) {
-    // We don't have a direct function symbol, but that may be because of
-    // constant cast instructions in the call.
-
-    // With bitcast'd call targets, the instruction will be the call
-    if (const auto *CI = dyn_cast<CallInst>(CB)) {
-      // Check if we have call alignment metadata
-      if (MaybeAlign StackAlign = getAlign(*CI, Idx))
-        return StackAlign.value();
-    }
-    DirectCallee = getMaybeBitcastedCallee(CB);
-  }
-
-  // Check for function alignment information if we found that the
-  // ultimate target is a Function
-  if (DirectCallee)
-    return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
-
-  // Call is indirect, fall back to the ABI type alignment
-  return DL.getABITypeAlign(Ty);
-}
 
 static bool shouldConvertToIndirectCall(const CallBase *CB,
                                         const GlobalAddressSDNode *Func) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 6cf808bc9c858..10630f8b6bc43 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -14,6 +14,7 @@
 #include "NVPTX.h"
 #include "NVPTXTargetMachine.h"
 #include "NVVMProperties.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
 #include "llvm/Support/Alignment.h"
@@ -78,6 +79,39 @@ Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
   return ArgAlign;
 }
 
+Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
+                              const DataLayout &DL) {
+  if (MaybeAlign StackAlign =
+          getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
+    return StackAlign.value();
+
+  Align TypeAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
+  MaybeAlign ParamAlign =
+      Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
+  return std::max(TypeAlign, ParamAlign.valueOrOne());
+}
+
+Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
+                           const DataLayout &DL) {
+  if (!CB)
+    return DL.getABITypeAlign(Ty);
+
+  const Function *DirectCallee = CB->getCalledFunction();
+
+  if (!DirectCallee) {
+    if (const auto *CI = dyn_cast<CallInst>(CB)) {
+      if (MaybeAlign StackAlign = getAlign(*CI, Idx))
+        return StackAlign.value();
+    }
+    DirectCallee = getMaybeBitcastedCallee(CB);
+  }
+
+  if (DirectCallee)
+    return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
+
+  return DL.getABITypeAlign(Ty);
+}
+
 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
   const auto &ST =
       *static_cast<const NVPTXTargetMachine &>(TM).getSubtargetImpl();
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index b6e18bf998897..7d07206f36c11 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -47,6 +47,12 @@ Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
 Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
                                  Align InitialAlign, const DataLayout &DL);
 
+Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
+                              const DataLayout &DL);
+
+Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
+                           const DataLayout &DL);
+
 // PTX ABI requires all scalar argument/return values to have
 // bit-size as a power of two of at least 32 bits.
 inline unsigned promoteScalarArgumentSize(unsigned size) {

>From 882101971964875e1253f79ac4daf92b97aaecc4 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 25 Mar 2026 20:08:25 +0000
Subject: [PATCH 3/5] refactor

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     | 13 +--
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 23 ++---
 .../Target/NVPTX/NVPTXSetByValParamAlign.cpp  |  2 +-
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp      | 89 +++++++++----------
 llvm/lib/Target/NVPTX/NVPTXUtilities.h        | 34 ++++---
 llvm/lib/Target/NVPTX/NVVMProperties.cpp      |  2 +-
 llvm/lib/Target/NVPTX/NVVMProperties.h        |  4 +-
 .../CodeGen/NVPTX/nvvm-annotations-D120129.ll |  2 +-
 8 files changed, 86 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 7763596a93afb..2d79bd4924e20 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -276,7 +276,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   if (shouldPassAsArray(Ty)) {
     const unsigned TotalSize = DL.getTypeAllocSize(Ty);
     const Align RetAlignment =
-        getFunctionArgumentAlignment(F, Ty, AttributeList::ReturnIndex, DL);
+        getParamAlign(F, Ty, AttributeList::ReturnIndex, DL);
     O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
       << TotalSize << "]";
   } else if (Ty->isFloatingPointTy()) {
@@ -1392,9 +1392,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
       const Align OptimalAlign =
-          IsKernelFunc ? getOptimalAlignForParam(F, Arg, ETy, DL)
-                       : getFunctionByValParamAlign(
-                             F, ETy, Arg.getParamAlign().valueOrOne(), DL);
+          IsKernelFunc
+              ? getParamAlign(F, ETy,
+                              Arg.getArgNo() + AttributeList::FirstArgIndex, DL)
+              : getDeviceByValParamAlign(F, ETy,
+                                         Arg.getParamAlign().valueOrOne(), DL);
 
       O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
         << "[" << DL.getTypeAllocSize(ETy) << "]";
@@ -1406,7 +1408,8 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
       // size = typeallocsize of element type
-      Align OptimalAlign = getOptimalAlignForParam(F, Arg, Ty, DL);
+      Align OptimalAlign = getParamAlign(
+          F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
 
       O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
         << "[" << DL.getTypeAllocSize(Ty) << "]";
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index af1d8cdaa0f27..a8d8f8cd351cb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1202,7 +1202,8 @@ std::string NVPTXTargetLowering::getPrototype(
   } else {
     O << "(";
     if (shouldPassAsArray(RetTy)) {
-      const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
+      const Align RetAlign =
+          getParamAlign(&CB, RetTy, AttributeList::ReturnIndex, DL);
       O << ".param .align " << RetAlign.value() << " .b8 _["
         << DL.getTypeAllocSize(RetTy) << "]";
     } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
@@ -1250,14 +1251,14 @@ std::string NVPTXTargetLowering::getPrototype(
       Type *ETy = Args[I].IndirectType;
       Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
       Align ParamByValAlign =
-          getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
+          getDeviceByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
 
       O << ".param .align " << ParamByValAlign.value() << " .b8 _["
         << ArgOuts[0].Flags.getByValSize() << "]";
     } else {
       if (shouldPassAsArray(Ty)) {
         Align ParamAlign =
-            getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+            getParamAlign(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
         O << ".param .align " << ParamAlign.value() << " .b8 _["
           << DL.getTypeAllocSize(Ty) << "]";
         continue;
@@ -1466,10 +1467,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         // so we don't need to worry whether it's naturally aligned or not.
         // See TargetLowering::LowerCallTo().
         const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-        return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
-                                          InitialAlign, DL);
+        return getDeviceByValParamAlign(CB->getCalledFunction(), ETy,
+                                        InitialAlign, DL);
       }
-      return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
+      return getParamAlign(CB, Arg.Ty, ArgI + AttributeList::FirstArgIndex, DL);
     }();
 
     const unsigned TySize = DL.getTypeAllocSize(ETy);
@@ -1603,7 +1604,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
     const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
     if (shouldPassAsArray(RetTy)) {
-      const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+      const Align RetAlign =
+          getParamAlign(CB, RetTy, AttributeList::ReturnIndex, DL);
       MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
     } else {
       MakeDeclareScalarParam(RetSymbol, ResultSize);
@@ -1682,7 +1684,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
-    const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+    const Align RetAlign =
+        getParamAlign(CB, RetTy, AttributeList::ReturnIndex, DL);
     const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
@@ -4100,7 +4103,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       assert(VTs.size() == ArgIns.size() && "Size mismatch");
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
-      const Align ArgAlign = getFunctionArgumentAlignment(
+      const Align ArgAlign = getParamAlign(
           &F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
 
       unsigned I = 0;
@@ -4158,7 +4161,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
   const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
   const auto RetAlign =
-      getFunctionArgumentAlignment(&F, RetTy, AttributeList::ReturnIndex, DL);
+      getParamAlign(&F, RetTy, AttributeList::ReturnIndex, DL);
 
   // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
   // 32-bits are sign extended or zero extended, depending on whether
diff --git a/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp b/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp
index 214078c1967ab..45d4965ed2c9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSetByValParamAlign.cpp
@@ -64,7 +64,7 @@ static Align setByValParamAlign(Argument *Arg) {
   Type *ByValType = Arg->getParamByValType();
   const DataLayout &DL = F->getDataLayout();
 
-  const Align OptimizedAlign = getFunctionParamOptimizedAlign(F, ByValType, DL);
+  const Align OptimizedAlign = getPromotedParamTypeAlign(F, ByValType, DL);
   const Align CurrentAlign = Arg->getParamAlign().valueOrOne();
 
   if (CurrentAlign >= OptimizedAlign)
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 10630f8b6bc43..56e71e476b3ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -14,7 +14,7 @@
 #include "NVPTX.h"
 #include "NVPTXTargetMachine.h"
 #include "NVVMProperties.h"
-#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
 #include "llvm/Support/Alignment.h"
@@ -33,8 +33,8 @@ Function *getMaybeBitcastedCallee(const CallBase *CB) {
   return dyn_cast<Function>(CB->getCalledOperand()->stripPointerCasts());
 }
 
-Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
-                                     const DataLayout &DL) {
+Align getPromotedParamTypeAlign(const Function *F, Type *ArgTy,
+                                const DataLayout &DL) {
   // Capping the alignment to 128 bytes as that is the maximum alignment
   // supported by PTX.
   const Align ABITypeAlign = std::min(Align(128), DL.getABITypeAlign(ArgTy));
@@ -42,27 +42,21 @@ Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
   // If a function has linkage different from internal or private, we
   // must use default ABI alignment as external users rely on it. Same
   // for a function that may be called from a function pointer.
-  if (!F || !F->hasLocalLinkage() ||
-      F->hasAddressTaken(/*Users=*/nullptr,
-                         /*IgnoreCallbackUses=*/false,
-                         /*IgnoreAssumeLikeCalls=*/true,
-                         /*IgnoreLLVMUsed=*/true))
-    return ABITypeAlign;
-
-  assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
-  return std::max(Align(16), ABITypeAlign);
+  const bool MayOptimizeAlign =
+      F && F->hasLocalLinkage() &&
+      !F->hasAddressTaken(/*Users=*/nullptr,
+                          /*IgnoreCallbackUses=*/false,
+                          /*IgnoreAssumeLikeCalls=*/true,
+                          /*IgnoreLLVMUsed=*/true);
+  assert(!(MayOptimizeAlign && isKernelFunction(*F)) &&
+         "Expect kernels to have non-local linkage");
+  const Align OptimizedAlign = MayOptimizeAlign ? Align(16) : Align(1);
+  return std::max(OptimizedAlign, ABITypeAlign);
 }
 
-Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
-                                   const DataLayout &DL) {
-  return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
-}
-
-Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
-                                 Align InitialAlign, const DataLayout &DL) {
-  Align ArgAlign = InitialAlign;
-  if (F)
-    ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
+Align getDeviceByValParamAlign(const Function *F, Type *ArgTy,
+                               Align InitialAlign, const DataLayout &DL) {
+  const Align OptimizedAlign = getPromotedParamTypeAlign(F, ArgTy, DL);
 
   // Old ptx versions have a bug. When PTX code takes address of
   // byval parameter with alignment < 4, ptxas generates code to
@@ -73,43 +67,40 @@ Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
   // ptxas > 9.0.
   // TODO: remove this after verifying the bug is not reproduced
   // on non-deprecated ptxas versions.
-  if (ForceMinByValParamAlign)
-    ArgAlign = std::max(ArgAlign, Align(4));
+  const bool ShouldForceMinAlign =
+      ForceMinByValParamAlign && (!F || !isKernelFunction(*F));
+  const Align AlignFloor = ShouldForceMinAlign ? Align(4) : Align(1);
 
-  return ArgAlign;
+  return std::max({InitialAlign, OptimizedAlign, AlignFloor});
 }
 
-Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
-                              const DataLayout &DL) {
-  if (MaybeAlign StackAlign =
-          getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
-    return StackAlign.value();
-
-  Align TypeAlign = getFunctionParamOptimizedAlign(F, Ty, DL);
-  MaybeAlign ParamAlign =
-      Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
-  return std::max(TypeAlign, ParamAlign.valueOrOne());
+Align getParamAlign(const Function *F, Type *Ty, unsigned AttrIdx,
+                    const DataLayout &DL) {
+  if (F)
+    if (MaybeAlign StackAlign = getStackAlign(*F, AttrIdx))
+      return StackAlign.value();
+
+  Align TypeAlign = getPromotedParamTypeAlign(F, Ty, DL);
+  if (F && AttrIdx >= AttributeList::FirstArgIndex) {
+    unsigned ArgNo = AttrIdx - AttributeList::FirstArgIndex;
+    if (F->getAttributes().hasParamAttr(ArgNo, Attribute::ByVal))
+      return std::max(TypeAlign, F->getParamAlign(ArgNo).valueOrOne());
+  }
+  return TypeAlign;
 }
 
-Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
-                           const DataLayout &DL) {
-  if (!CB)
-    return DL.getABITypeAlign(Ty);
+Align getParamAlign(const CallBase *CB, Type *Ty, unsigned Idx,
+                    const DataLayout &DL) {
+  const Function *DirectCallee = CB ? CB->getCalledFunction() : nullptr;
 
-  const Function *DirectCallee = CB->getCalledFunction();
+  if (!DirectCallee && CB) {
+    if (MaybeAlign StackAlign = getStackAlign(*CB, Idx))
+      return StackAlign.value();
 
-  if (!DirectCallee) {
-    if (const auto *CI = dyn_cast<CallInst>(CB)) {
-      if (MaybeAlign StackAlign = getAlign(*CI, Idx))
-        return StackAlign.value();
-    }
     DirectCallee = getMaybeBitcastedCallee(CB);
   }
 
-  if (DirectCallee)
-    return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
-
-  return DL.getABITypeAlign(Ty);
+  return getParamAlign(DirectCallee, Ty, Idx, DL);
 }
 
 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 7d07206f36c11..7e1eca7284bb9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -38,20 +38,26 @@ Function *getMaybeBitcastedCallee(const CallBase *CB);
 /// function has internal or private linkage as for other linkage types callers
 /// may already rely on default alignment. To allow using 128-bit vectorized
 /// loads/stores, this function ensures that alignment is 16 or greater.
-Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
-                                     const DataLayout &DL);
-
-Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
-                                   const DataLayout &DL);
-
-Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
-                                 Align InitialAlign, const DataLayout &DL);
-
-Align getOptimalAlignForParam(const Function *F, const Argument &Arg, Type *Ty,
-                              const DataLayout &DL);
-
-Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
-                           const DataLayout &DL);
+Align getPromotedParamTypeAlign(const Function *F, Type *ArgTy,
+                                const DataLayout &DL);
+
+Align getDeviceByValParamAlign(const Function *F, Type *ArgTy,
+                               Align InitialAlign, const DataLayout &DL);
+
+/// Get the alignment for a function parameter or return value.
+/// \p AttrIdx is the AttributeList index (e.g. FirstArgIndex + argNo, or
+/// ReturnIndex for return values). Checks for an explicit alignment attribute,
+/// then falls back to getPromotedParamTypeAlign, incorporating byval param
+/// alignment when applicable.
+Align getParamAlign(const Function *F, Type *Ty, unsigned AttrIdx,
+                    const DataLayout &DL);
+
+/// Get the alignment for a call-site argument or return value. Resolves the
+/// callee and delegates to the Function overload of getParamAlign. For calls
+/// with no direct callee, a call-site stack alignment takes precedence; an
+/// unresolvable callee falls back to getPromotedParamTypeAlign.
+Align getParamAlign(const CallBase *CB, Type *Ty, unsigned AttrIdx,
+                    const DataLayout &DL);
 
 // PTX ABI requires all scalar argument/return values to have
 // bit-size as a power of two of at least 32 bits.
diff --git a/llvm/lib/Target/NVPTX/NVVMProperties.cpp b/llvm/lib/Target/NVPTX/NVVMProperties.cpp
index d68c5aaf4fe5f..012d0863c183d 100644
--- a/llvm/lib/Target/NVPTX/NVVMProperties.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMProperties.cpp
@@ -317,7 +317,7 @@ bool isParamGridConstant(const Argument &Arg) {
   return Arg.hasAttribute(NVVMAttr::GridConstant);
 }
 
-MaybeAlign getAlign(const CallInst &I, unsigned Index) {
+MaybeAlign getStackAlign(const CallBase &I, unsigned Index) {
   // First check the alignstack metadata.
   if (MaybeAlign StackAlign =
           I.getAttributes().getAttributes(Index).getStackAlignment())
diff --git a/llvm/lib/Target/NVPTX/NVVMProperties.h b/llvm/lib/Target/NVPTX/NVVMProperties.h
index 6ccd6f8a20075..601b4827bbc8d 100644
--- a/llvm/lib/Target/NVPTX/NVVMProperties.h
+++ b/llvm/lib/Target/NVPTX/NVVMProperties.h
@@ -59,10 +59,10 @@ bool hasBlocksAreClusters(const Function &);
 
 bool isParamGridConstant(const Argument &);
 
-inline MaybeAlign getAlign(const Function &F, unsigned Index) {
+inline MaybeAlign getStackAlign(const Function &F, unsigned Index) {
   return F.getAttributes().getAttributes(Index).getStackAlignment();
 }
-MaybeAlign getAlign(const CallInst &, unsigned);
+MaybeAlign getStackAlign(const CallBase &, unsigned);
 
 } // namespace llvm
 
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll b/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll
index 0b8d247b0bca6..bbc83cfc3b7c8 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-annotations-D120129.ll
@@ -1,7 +1,7 @@
 ; RUN: llc < %s -mtriple=nvptx64-unknown-unknown | FileCheck %s
 ; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64-unknown-unknown | %ptxas-verify %}
 ;
-; NVPTXTargetLowering::getFunctionParamOptimizedAlign, which was introduces in
+; NVPTXTargetLowering::getPromotedParamTypeAlign, which was introduced in
 ; D120129, contained a poorly designed assertion checking that a function with
 ; internal or private linkage is not a kernel. It relied on invariants that
 ; were not actually guaranteed, and that resulted in compiler crash with some

>From 07f8570850e8c34d431e0400c1a3288ecf5acb7b Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Mon, 30 Mar 2026 16:53:37 +0000
Subject: [PATCH 4/5] fixup

---
 llvm/lib/Target/NVPTX/NVVMProperties.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMProperties.h b/llvm/lib/Target/NVPTX/NVVMProperties.h
index 601b4827bbc8d..7187b18d3cbf7 100644
--- a/llvm/lib/Target/NVPTX/NVVMProperties.h
+++ b/llvm/lib/Target/NVPTX/NVVMProperties.h
@@ -24,7 +24,7 @@
 namespace llvm {
 
 class Argument;
-class CallInst;
+class CallBase;
 class GlobalVariable;
 class Module;
 class Value;

>From b45b08b3306f393d6d912c868662aeeb1af93b7c Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 2 Apr 2026 18:46:55 +0000
Subject: [PATCH 5/5] clang-format

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a8d8f8cd351cb..f50ae5db458e7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1291,7 +1291,6 @@ std::string NVPTXTargetLowering::getPrototype(
   return Prototype;
 }
 
-
 static bool shouldConvertToIndirectCall(const CallBase *CB,
                                         const GlobalAddressSDNode *Func) {
   if (!Func)



More information about the llvm-commits mailing list