[llvm] [NVPTX] fixup support for over-aligned parameters (PR #92457)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri May 17 12:19:02 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/92457

From 58cc333c6d1e85c4cd6beef8ac4d5ca94e4a733f Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Mon, 13 May 2024 21:14:43 +0000
Subject: [PATCH 1/2] [NVPTX] fixup support for over-aligned parameters

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp   |  14 ++-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  33 +++---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h   |   3 +
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp    |  47 +++++----
 llvm/lib/Target/NVPTX/NVPTXUtilities.h      |   5 +-
 llvm/test/CodeGen/NVPTX/param-overalign.ll  | 109 ++++++++++++++++++++
 6 files changed, 168 insertions(+), 43 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/param-overalign.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9f31b72bbceb1..dc9377df208d2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -72,6 +72,7 @@
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/Alignment.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Endian.h"
@@ -370,11 +371,10 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
         << " func_retval0";
     } else if (ShouldPassAsArray(Ty)) {
       unsigned totalsz = DL.getTypeAllocSize(Ty);
-      unsigned retAlignment = 0;
-      if (!getAlign(*F, 0, retAlignment))
-        retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
-      O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
-        << "]";
+      Align RetAlignment = TLI->getFunctionArgumentAlignment(
+          F, Ty, AttributeList::ReturnIndex, DL);
+      O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
+        << totalsz << "]";
     } else
       llvm_unreachable("Unknown return type");
   } else {
@@ -1558,6 +1558,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
     auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
                                     paramIndex](Type *Ty) -> Align {
+      if (MaybeAlign StackAlign =
+              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
+        return StackAlign.value();
+
       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
       MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
       return std::max(TypeAlign, ParamAlign.valueOrOne());
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b03803f52b78e..1e7477cf9d60e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1434,12 +1434,11 @@ std::string NVPTXTargetLowering::getPrototype(
 
     if (!Outs[OIdx].Flags.isByVal()) {
       if (IsTypePassedAsArray(Ty)) {
-        unsigned ParamAlign = 0;
         const CallInst *CallI = cast<CallInst>(&CB);
-        // +1 because index 0 is reserved for return type alignment
-        if (!getAlign(*CallI, i + 1, ParamAlign))
-          ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value();
-        O << ".param .align " << ParamAlign << " .b8 ";
+        Align ParamAlign =
+            getAlign(*CallI, i + AttributeList::FirstArgIndex)
+                .value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
+        O << ".param .align " << ParamAlign.value() << " .b8 ";
         O << "_";
         O << "[" << DL.getTypeAllocSize(Ty) << "]";
         // update the index for Outs
@@ -1489,6 +1488,11 @@ std::string NVPTXTargetLowering::getPrototype(
   return Prototype;
 }
 
+Align NVPTXTargetLowering::getFunctionArgumentAlignment(
+    const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
+  return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
+}
+
 Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
                                                 unsigned Idx,
                                                 const DataLayout &DL) const {
@@ -1497,7 +1501,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
     return DL.getABITypeAlign(Ty);
   }
 
-  unsigned Alignment = 0;
   const Function *DirectCallee = CB->getCalledFunction();
 
   if (!DirectCallee) {
@@ -1507,21 +1510,16 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
     // With bitcast'd call targets, the instruction will be the call
     if (const auto *CI = dyn_cast<CallInst>(CB)) {
       // Check if we have call alignment metadata
-      if (getAlign(*CI, Idx, Alignment))
-        return Align(Alignment);
+      if (MaybeAlign StackAlign = getAlign(*CI, Idx))
+        return StackAlign.value();
     }
     DirectCallee = getMaybeBitcastedCallee(CB);
   }
 
   // Check for function alignment information if we found that the
   // ultimate target is a Function
-  if (DirectCallee) {
-    if (getAlign(*DirectCallee, Idx, Alignment))
-      return Align(Alignment);
-    // If alignment information is not available, fall back to the
-    // default function param optimized type alignment
-    return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL);
-  }
+  if (DirectCallee)
+    return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
 
   // Call is indirect, fall back to the ABI type alignment
   return DL.getABITypeAlign(Ty);
@@ -3195,8 +3193,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       if (VTs.empty())
         report_fatal_error("Empty parameter types are not supported");
 
-      auto VectorInfo =
-          VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlign(Ty));
+      Align ArgAlign = getFunctionArgumentAlignment(
+          F, Ty, i + AttributeList::FirstArgIndex, DL);
+      auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
 
       SDValue Arg = getParamSymbol(DAG, i, PtrVT);
       int VecIdx = -1; // Index of the first element of the current vector.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index c9db10e555cef..e211286fcc556 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -462,6 +462,9 @@ class NVPTXTargetLowering : public TargetLowering {
                           MachineFunction &MF,
                           unsigned Intrinsic) const override;
 
+  Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
+                                     const DataLayout &DL) const;
+
   /// getFunctionParamOptimizedAlign - since function arguments are passed via
   /// .param space, we may want to increase their alignment in a way that
   /// ensures that we can effectively vectorize their loads & stores. We can
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 35302889095f8..80896a5bc4fd9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -19,11 +19,13 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/Support/Alignment.h"
 #include "llvm/Support/Mutex.h"
 #include <algorithm>
 #include <cstring>
 #include <map>
 #include <mutex>
+#include <optional>
 #include <string>
 #include <vector>
 
@@ -296,37 +298,44 @@ bool isKernelFunction(const Function &F) {
   return (x == 1);
 }
 
-bool getAlign(const Function &F, unsigned index, unsigned &align) {
+MaybeAlign getAlign(const Function &F, unsigned Index) {
+  // First check the alignstack metadata
+  if (MaybeAlign AlignStack =
+          F.getAttributes().getAttributes(Index).getStackAlignment())
+    return AlignStack;
+
+  // If that is missing, check the legacy nvvm metadata
   std::vector<unsigned> Vs;
   bool retval = findAllNVVMAnnotation(&F, "align", Vs);
   if (!retval)
-    return false;
-  for (unsigned v : Vs) {
-    if ((v >> 16) == index) {
-      align = v & 0xFFFF;
-      return true;
-    }
-  }
-  return false;
+    return std::nullopt;
+  for (unsigned V : Vs)
+    if ((V >> 16) == Index)
+      return Align(V & 0xFFFF);
+
+  return std::nullopt;
 }
 
-bool getAlign(const CallInst &I, unsigned index, unsigned &align) {
+MaybeAlign getAlign(const CallInst &I, unsigned Index) {
+  // First check the alignstack metadata
+  if (MaybeAlign AlignStack =
+          I.getAttributes().getAttributes(Index).getStackAlignment())
+    return AlignStack;
+
+  // If that is missing, check the legacy nvvm metadata
   if (MDNode *alignNode = I.getMetadata("callalign")) {
     for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
       if (const ConstantInt *CI =
               mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
-        unsigned v = CI->getZExtValue();
-        if ((v >> 16) == index) {
-          align = v & 0xFFFF;
-          return true;
-        }
-        if ((v >> 16) > index) {
-          return false;
-        }
+        unsigned V = CI->getZExtValue();
+        if ((V >> 16) == Index)
+          return Align(V & 0xFFFF);
+        if ((V >> 16) > Index)
+          return std::nullopt;
       }
     }
   }
-  return false;
+  return std::nullopt;
 }
 
 Function *getMaybeBitcastedCallee(const CallBase *CB) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 449973bb53de7..2872db9fa2131 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -18,6 +18,7 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Value.h"
+#include "llvm/Support/Alignment.h"
 #include <cstdarg>
 #include <set>
 #include <string>
@@ -60,8 +61,8 @@ bool getMinCTASm(const Function &, unsigned &);
 bool getMaxNReg(const Function &, unsigned &);
 bool isKernelFunction(const Function &);
 
-bool getAlign(const Function &, unsigned index, unsigned &);
-bool getAlign(const CallInst &, unsigned index, unsigned &);
+MaybeAlign getAlign(const Function &, unsigned);
+MaybeAlign getAlign(const CallInst &, unsigned);
 Function *getMaybeBitcastedCallee(const CallBase *CB);
 
 // PTX ABI requires all scalar argument/return values to have
diff --git a/llvm/test/CodeGen/NVPTX/param-overalign.ll b/llvm/test/CodeGen/NVPTX/param-overalign.ll
new file mode 100644
index 0000000000000..63e706982f394
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/param-overalign.ll
@@ -0,0 +1,109 @@
+; RUN: llc < %s -march=nvptx | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -verify-machineinstrs | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.float2 = type { float, float }
+
+; CHECK-LABEL: .visible .func  (.param .b32 func_retval0) callee_md
+; CHECK-NEXT: (
+; CHECK-NEXT:         .param .align 8 .b8 callee_md_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: ;
+
+; CHECK-LABEL: .visible .func  (.param .b32 func_retval0) callee
+; CHECK-NEXT: (
+; CHECK-NEXT:         .param .align 8 .b8 callee_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: ;
+
+define float @caller_md(float %a, float %b) {
+; CHECK-LABEL: .visible .func  (.param .b32 func_retval0) caller_md(
+; CHECK-NEXT:         .param .b32 caller_md_param_0,
+; CHECK-NEXT:         .param .b32 caller_md_param_1
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK:         ld.param.f32 %f1, [caller_md_param_0];
+; CHECK-NEXT:    ld.param.f32 %f2, [caller_md_param_1];
+; CHECK-NEXT:    {
+; CHECK-NEXT:    .param .align 8 .b8 param0[8];
+; CHECK-NEXT:    st.param.v2.f32 [param0+0], {%f1, %f2};
+; CHECK-NEXT:    .param .b32 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee_md,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    ld.param.f32 %f3, [retval0+0];
+; CHECK-NEXT:    }
+; CHECK-NEXT:    st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT:    ret;
+  %s1 = insertvalue %struct.float2 poison, float %a, 0
+  %s2 = insertvalue %struct.float2 %s1, float %b, 1
+  %r = call float @callee_md(%struct.float2 %s2)
+  ret float %r
+}
+
+define float @callee_md(%struct.float2 %a) {
+; CHECK-LABEL: .visible .func  (.param .b32 func_retval0) callee_md(
+; CHECK-NEXT:         .param .align 8 .b8 callee_md_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK:         ld.param.v2.f32 {%f1, %f2}, [callee_md_param_0];
+; CHECK-NEXT:    add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT:    st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT:    ret;
+  %v0 = extractvalue %struct.float2 %a, 0
+  %v1 = extractvalue %struct.float2 %a, 1
+  %2 = fadd float %v0, %v1
+  ret float %2
+}
+
+define float @caller(float %a, float %b) {
+; CHECK-LABEL: .visible .func  (.param .b32 func_retval0) caller(
+; CHECK-NEXT:         .param .b32 caller_param_0,
+; CHECK-NEXT:         .param .b32 caller_param_1
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK:         ld.param.f32 %f1, [caller_param_0];
+; CHECK-NEXT:    ld.param.f32 %f2, [caller_param_1];
+; CHECK-NEXT:    {
+; CHECK-NEXT:    .param .align 8 .b8 param0[8];
+; CHECK-NEXT:    st.param.v2.f32 [param0+0], {%f1, %f2};
+; CHECK-NEXT:    .param .b32 retval0;
+; CHECK-NEXT:    call.uni (retval0),
+; CHECK-NEXT:    callee,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    ld.param.f32 %f3, [retval0+0];
+; CHECK-NEXT:    }
+; CHECK-NEXT:    st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT:    ret;
+  %s1 = insertvalue %struct.float2 poison, float %a, 0
+  %s2 = insertvalue %struct.float2 %s1, float %b, 1
+  %r = call float @callee(%struct.float2 %s2)
+  ret float %r
+}
+
+define float @callee(%struct.float2 alignstack(8) %a ) {
+; CHECK-LABEL: .visible .func  (.param .b32 func_retval0) callee(
+; CHECK-NEXT:         .param .align 8 .b8 callee_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK:         ld.param.v2.f32 {%f1, %f2}, [callee_param_0];
+; CHECK-NEXT:    add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT:    st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT:    ret;
+  %v0 = extractvalue %struct.float2 %a, 0
+  %v1 = extractvalue %struct.float2 %a, 1
+  %2 = fadd float %v0, %v1
+  ret float %2
+}
+
+!nvvm.annotations = !{!0}
+!0 = !{ptr @callee_md, !"align", i32 u0x00010008}

From 0b60083e6ed5545d668999c2c6316711f3d50add Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Fri, 17 May 2024 19:18:26 +0000
Subject: [PATCH 2/2] address comments

---
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80896a5bc4fd9..013afe916e86c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -300,9 +300,9 @@ bool isKernelFunction(const Function &F) {
 
 MaybeAlign getAlign(const Function &F, unsigned Index) {
   // First check the alignstack metadata
-  if (MaybeAlign AlignStack =
+  if (MaybeAlign StackAlign =
           F.getAttributes().getAttributes(Index).getStackAlignment())
-    return AlignStack;
+    return StackAlign;
 
   // If that is missing, check the legacy nvvm metadata
   std::vector<unsigned> Vs;
@@ -318,9 +318,9 @@ MaybeAlign getAlign(const Function &F, unsigned Index) {
 
 MaybeAlign getAlign(const CallInst &I, unsigned Index) {
   // First check the alignstack metadata
-  if (MaybeAlign AlignStack =
+  if (MaybeAlign StackAlign =
           I.getAttributes().getAttributes(Index).getStackAlignment())
-    return AlignStack;
+    return StackAlign;
 
   // If that is missing, check the legacy nvvm metadata
   if (MDNode *alignNode = I.getMetadata("callalign")) {



More information about the llvm-commits mailing list