[llvm] [NVPTX] fixup support for over-aligned parameters (PR #92457)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Fri May 17 12:19:02 PDT 2024
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/92457
>From 58cc333c6d1e85c4cd6beef8ac4d5ca94e4a733f Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Mon, 13 May 2024 21:14:43 +0000
Subject: [PATCH 1/2] [NVPTX] fixup support for over-aligned parameters
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 14 ++-
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 33 +++---
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 3 +
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 47 +++++----
llvm/lib/Target/NVPTX/NVPTXUtilities.h | 5 +-
llvm/test/CodeGen/NVPTX/param-overalign.ll | 109 ++++++++++++++++++++
6 files changed, 168 insertions(+), 43 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/param-overalign.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9f31b72bbceb1..dc9377df208d2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -72,6 +72,7 @@
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
@@ -370,11 +371,10 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
<< " func_retval0";
} else if (ShouldPassAsArray(Ty)) {
unsigned totalsz = DL.getTypeAllocSize(Ty);
- unsigned retAlignment = 0;
- if (!getAlign(*F, 0, retAlignment))
- retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
- O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
- << "]";
+ Align RetAlignment = TLI->getFunctionArgumentAlignment(
+ F, Ty, AttributeList::ReturnIndex, DL);
+ O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
+ << totalsz << "]";
} else
llvm_unreachable("Unknown return type");
} else {
@@ -1558,6 +1558,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
paramIndex](Type *Ty) -> Align {
+ if (MaybeAlign StackAlign =
+ getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
+ return StackAlign.value();
+
Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
return std::max(TypeAlign, ParamAlign.valueOrOne());
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b03803f52b78e..1e7477cf9d60e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1434,12 +1434,11 @@ std::string NVPTXTargetLowering::getPrototype(
if (!Outs[OIdx].Flags.isByVal()) {
if (IsTypePassedAsArray(Ty)) {
- unsigned ParamAlign = 0;
const CallInst *CallI = cast<CallInst>(&CB);
- // +1 because index 0 is reserved for return type alignment
- if (!getAlign(*CallI, i + 1, ParamAlign))
- ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value();
- O << ".param .align " << ParamAlign << " .b8 ";
+ Align ParamAlign =
+ getAlign(*CallI, i + AttributeList::FirstArgIndex)
+ .value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
+ O << ".param .align " << ParamAlign.value() << " .b8 ";
O << "_";
O << "[" << DL.getTypeAllocSize(Ty) << "]";
// update the index for Outs
@@ -1489,6 +1488,11 @@ std::string NVPTXTargetLowering::getPrototype(
return Prototype;
}
+Align NVPTXTargetLowering::getFunctionArgumentAlignment(
+ const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
+ return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL));
+}
+
Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
unsigned Idx,
const DataLayout &DL) const {
@@ -1497,7 +1501,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
return DL.getABITypeAlign(Ty);
}
- unsigned Alignment = 0;
const Function *DirectCallee = CB->getCalledFunction();
if (!DirectCallee) {
@@ -1507,21 +1510,16 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
// With bitcast'd call targets, the instruction will be the call
if (const auto *CI = dyn_cast<CallInst>(CB)) {
// Check if we have call alignment metadata
- if (getAlign(*CI, Idx, Alignment))
- return Align(Alignment);
+ if (MaybeAlign StackAlign = getAlign(*CI, Idx))
+ return StackAlign.value();
}
DirectCallee = getMaybeBitcastedCallee(CB);
}
// Check for function alignment information if we found that the
// ultimate target is a Function
- if (DirectCallee) {
- if (getAlign(*DirectCallee, Idx, Alignment))
- return Align(Alignment);
- // If alignment information is not available, fall back to the
- // default function param optimized type alignment
- return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL);
- }
+ if (DirectCallee)
+ return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL);
// Call is indirect, fall back to the ABI type alignment
return DL.getABITypeAlign(Ty);
@@ -3195,8 +3193,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (VTs.empty())
report_fatal_error("Empty parameter types are not supported");
- auto VectorInfo =
- VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlign(Ty));
+ Align ArgAlign = getFunctionArgumentAlignment(
+ F, Ty, i + AttributeList::FirstArgIndex, DL);
+ auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
int VecIdx = -1; // Index of the first element of the current vector.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index c9db10e555cef..e211286fcc556 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -462,6 +462,9 @@ class NVPTXTargetLowering : public TargetLowering {
MachineFunction &MF,
unsigned Intrinsic) const override;
+ Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
+ const DataLayout &DL) const;
+
/// getFunctionParamOptimizedAlign - since function arguments are passed via
/// .param space, we may want to increase their alignment in a way that
/// ensures that we can effectively vectorize their loads & stores. We can
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 35302889095f8..80896a5bc4fd9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -19,11 +19,13 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
+#include "llvm/Support/Alignment.h"
#include "llvm/Support/Mutex.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <mutex>
+#include <optional>
#include <string>
#include <vector>
@@ -296,37 +298,44 @@ bool isKernelFunction(const Function &F) {
return (x == 1);
}
-bool getAlign(const Function &F, unsigned index, unsigned &align) {
+MaybeAlign getAlign(const Function &F, unsigned Index) {
+ // First check the alignstack metadata
+ if (MaybeAlign AlignStack =
+ F.getAttributes().getAttributes(Index).getStackAlignment())
+ return AlignStack;
+
+ // If that is missing, check the legacy nvvm metadata
std::vector<unsigned> Vs;
bool retval = findAllNVVMAnnotation(&F, "align", Vs);
if (!retval)
- return false;
- for (unsigned v : Vs) {
- if ((v >> 16) == index) {
- align = v & 0xFFFF;
- return true;
- }
- }
- return false;
+ return std::nullopt;
+ for (unsigned V : Vs)
+ if ((V >> 16) == Index)
+ return Align(V & 0xFFFF);
+
+ return std::nullopt;
}
-bool getAlign(const CallInst &I, unsigned index, unsigned &align) {
+MaybeAlign getAlign(const CallInst &I, unsigned Index) {
+ // First check the alignstack metadata
+ if (MaybeAlign AlignStack =
+ I.getAttributes().getAttributes(Index).getStackAlignment())
+ return AlignStack;
+
+ // If that is missing, check the legacy nvvm metadata
if (MDNode *alignNode = I.getMetadata("callalign")) {
for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) {
if (const ConstantInt *CI =
mdconst::dyn_extract<ConstantInt>(alignNode->getOperand(i))) {
- unsigned v = CI->getZExtValue();
- if ((v >> 16) == index) {
- align = v & 0xFFFF;
- return true;
- }
- if ((v >> 16) > index) {
- return false;
- }
+ unsigned V = CI->getZExtValue();
+ if ((V >> 16) == Index)
+ return Align(V & 0xFFFF);
+ if ((V >> 16) > Index)
+ return std::nullopt;
}
}
}
- return false;
+ return std::nullopt;
}
Function *getMaybeBitcastedCallee(const CallBase *CB) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 449973bb53de7..2872db9fa2131 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -18,6 +18,7 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
+#include "llvm/Support/Alignment.h"
#include <cstdarg>
#include <set>
#include <string>
@@ -60,8 +61,8 @@ bool getMinCTASm(const Function &, unsigned &);
bool getMaxNReg(const Function &, unsigned &);
bool isKernelFunction(const Function &);
-bool getAlign(const Function &, unsigned index, unsigned &);
-bool getAlign(const CallInst &, unsigned index, unsigned &);
+MaybeAlign getAlign(const Function &, unsigned);
+MaybeAlign getAlign(const CallInst &, unsigned);
Function *getMaybeBitcastedCallee(const CallBase *CB);
// PTX ABI requires all scalar argument/return values to have
diff --git a/llvm/test/CodeGen/NVPTX/param-overalign.ll b/llvm/test/CodeGen/NVPTX/param-overalign.ll
new file mode 100644
index 0000000000000..63e706982f394
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/param-overalign.ll
@@ -0,0 +1,109 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -verify-machineinstrs | %ptxas-verify %}
+; Over-aligned params (alignstack attr or legacy nvvm "align") keep that alignment.
+target triple = "nvptx64-nvidia-cuda"
+
+%struct.float2 = type { float, float }
+
+; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md
+; CHECK-NEXT: (
+; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: ;
+
+; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee
+; CHECK-NEXT: (
+; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: ;
+
+define float @caller_md(float %a, float %b) {
+; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller_md(
+; CHECK-NEXT: .param .b32 caller_md_param_0,
+; CHECK-NEXT: .param .b32 caller_md_param_1
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK: ld.param.f32 %f1, [caller_md_param_0];
+; CHECK-NEXT: ld.param.f32 %f2, [caller_md_param_1];
+; CHECK-NEXT: {
+; CHECK-NEXT: .param .align 8 .b8 param0[8];
+; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2};
+; CHECK-NEXT: .param .b32 retval0;
+; CHECK-NEXT: call.uni (retval0),
+; CHECK-NEXT: callee_md,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: ld.param.f32 %f3, [retval0+0];
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT: ret;
+ %s1 = insertvalue %struct.float2 poison, float %a, 0
+ %s2 = insertvalue %struct.float2 %s1, float %b, 1
+ %r = call float @callee_md(%struct.float2 %s2)
+ ret float %r
+}
+
+define float @callee_md(%struct.float2 %a) {
+; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md(
+; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_md_param_0];
+; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT: ret;
+ %v0 = extractvalue %struct.float2 %a, 0
+ %v1 = extractvalue %struct.float2 %a, 1
+ %2 = fadd float %v0, %v1
+ ret float %2
+}
+
+define float @caller(float %a, float %b) {
+; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller(
+; CHECK-NEXT: .param .b32 caller_param_0,
+; CHECK-NEXT: .param .b32 caller_param_1
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK: ld.param.f32 %f1, [caller_param_0];
+; CHECK-NEXT: ld.param.f32 %f2, [caller_param_1];
+; CHECK-NEXT: {
+; CHECK-NEXT: .param .align 8 .b8 param0[8];
+; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2};
+; CHECK-NEXT: .param .b32 retval0;
+; CHECK-NEXT: call.uni (retval0),
+; CHECK-NEXT: callee,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: ld.param.f32 %f3, [retval0+0];
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT: ret;
+ %s1 = insertvalue %struct.float2 poison, float %a, 0
+ %s2 = insertvalue %struct.float2 %s1, float %b, 1
+ %r = call float @callee(%struct.float2 %s2)
+ ret float %r
+}
+
+define float @callee(%struct.float2 alignstack(8) %a) {
+; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee(
+; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8]
+; CHECK-NEXT: )
+; CHECK-NEXT: {
+
+; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_param_0];
+; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2;
+; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
+; CHECK-NEXT: ret;
+ %v0 = extractvalue %struct.float2 %a, 0
+ %v1 = extractvalue %struct.float2 %a, 1
+ %2 = fadd float %v0, %v1
+ ret float %2
+}
+
+!nvvm.annotations = !{!0}
+!0 = !{ptr @callee_md, !"align", i32 u0x00010008}
>From 0b60083e6ed5545d668999c2c6316711f3d50add Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Fri, 17 May 2024 19:18:26 +0000
Subject: [PATCH 2/2] address review comments: rename AlignStack to StackAlign
---
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80896a5bc4fd9..013afe916e86c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -300,9 +300,9 @@ bool isKernelFunction(const Function &F) {
MaybeAlign getAlign(const Function &F, unsigned Index) {
// First check the alignstack metadata
- if (MaybeAlign AlignStack =
+ if (MaybeAlign StackAlign =
F.getAttributes().getAttributes(Index).getStackAlignment())
- return AlignStack;
+ return StackAlign;
// If that is missing, check the legacy nvvm metadata
std::vector<unsigned> Vs;
@@ -318,9 +318,9 @@ MaybeAlign getAlign(const Function &F, unsigned Index) {
MaybeAlign getAlign(const CallInst &I, unsigned Index) {
// First check the alignstack metadata
- if (MaybeAlign AlignStack =
+ if (MaybeAlign StackAlign =
I.getAttributes().getAttributes(Index).getStackAlignment())
- return AlignStack;
+ return StackAlign;
// If that is missing, check the legacy nvvm metadata
if (MDNode *alignNode = I.getMetadata("callalign")) {
More information about the llvm-commits mailing list