[llvm] [NVPTX] Basic support for fp128 as a storage type (PR #136006)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 16 16:24:07 PDT 2025
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/136006
>From db1baa6daffd6d7b466325c14dc87715b4ca4bfd Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 16 Apr 2025 15:50:52 +0000
Subject: [PATCH 1/3] [NVPTX] Basic support for fp128 as a storage type
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 74 +++++++++----------
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 28 +++----
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 4 -
llvm/lib/Target/NVPTX/NVPTXUtilities.h | 9 ++-
llvm/test/CodeGen/NVPTX/fp128-storage-type.ll | 56 ++++++++++++++
.../test/CodeGen/NVPTX/global-variable-big.ll | 5 +-
6 files changed, 115 insertions(+), 61 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 65cfeadc21a3b..e0f9a1ada3bc4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
return MCOperand::createExpr(Expr);
}
-static bool ShouldPassAsArray(Type *Ty) {
- return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
- Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
return;
O << " (";
- if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
- !ShouldPassAsArray(Ty)) {
- unsigned size = 0;
- if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
- size = ITy->getBitWidth();
- } else {
- assert(Ty->isFloatingPointTy() && "Floating point type expected here");
- size = Ty->getPrimitiveSizeInBits();
- }
- size = promoteScalarArgumentSize(size);
- O << ".param .b" << size << " func_retval0";
- } else if (isa<PointerType>(Ty)) {
- O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
- << " func_retval0";
- } else if (ShouldPassAsArray(Ty)) {
- unsigned totalsz = DL.getTypeAllocSize(Ty);
- Align RetAlignment = TLI->getFunctionArgumentAlignment(
+ auto PrintScalarParam = [&](unsigned Size) {
+ O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
+ };
+ if (shouldPassAsArray(Ty)) {
+ const unsigned TotalSize = DL.getTypeAllocSize(Ty);
+ const Align RetAlignment = TLI->getFunctionArgumentAlignment(
F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
- << totalsz << "]";
+ << TotalSize << "]";
+ } else if (Ty->isFloatingPointTy()) {
+ PrintScalarParam(Ty->getPrimitiveSizeInBits());
+ } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+ PrintScalarParam(ITy->getBitWidth());
+ } else if (isa<PointerType>(Ty)) {
+ PrintScalarParam(TLI->getPointerTy(DL).getSizeInBits());
} else
llvm_unreachable("Unknown return type");
O << ") ";
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
- if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
- (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
+ if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
+ ETy->getScalarSizeInBits() <= 64)) {
O << " .";
// Special case: ABI requires that we use .u8 for predicates
if (ETy->isIntegerTy(1))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// and vectors are lowered into arrays of bytes.
switch (ETy->getTypeID()) {
case Type::IntegerTyID: // Integers larger than 64 bits
+ case Type::FP128TyID:
case Type::StructTyID:
case Type::ArrayTyID:
case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
- // Special case for i128
- if (ETy->isIntegerTy(128)) {
+ // Special case for i128/fp128
+ if (ETy->getScalarSizeInBits() == 128) {
O << " .b8 ";
getSymbol(GVar)->print(O, MAI);
O << "[16]";
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
continue;
}
- if (ShouldPassAsArray(Ty)) {
+ if (shouldPassAsArray(Ty)) {
// Just print .param .align <a> .b8 .param[size];
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
@@ -1682,29 +1673,37 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
AggBuffer *aggBuffer) {
const DataLayout &DL = getDataLayout();
- int Bytes;
- // Integers of arbitrary width
- if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
- APInt Val = CI->getValue();
+ auto BufferConstant = [&](APInt Val) {
for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
uint8_t Byte = Val.getLoBits(8).getZExtValue();
aggBuffer->addBytes(&Byte, 1, 1);
Val.lshrInPlace(8);
}
+ };
+
+ // Integers of arbitrary width
+ if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
+ BufferConstant(CI->getValue());
return;
}
+ // f128
+ if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
+ if (CFP->getType()->isFP128Ty()) {
+ BufferConstant(CFP->getValueAPF().bitcastToAPInt());
+ return;
+ }
+ }
+
// Old constants
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
- if (CPV->getNumOperands())
- for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
- bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
+ for (const auto &Op : CPV->operands())
+ bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
return;
}
- if (const ConstantDataSequential *CDS =
- dyn_cast<ConstantDataSequential>(CPV)) {
+ if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
if (CDS->getNumElements())
for (unsigned i = 0; i < CDS->getNumElements(); ++i)
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
@@ -1716,6 +1715,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
+ int Bytes;
if (i == (e - 1))
Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
DL.getTypeAllocSize(ST) -
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..8d26785b898f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
SmallVector<uint64_t, 16> TempOffsets;
// Special case for i128 - decompose to (i64, i64)
- if (Ty->isIntegerTy(128)) {
- ValueVTs.push_back(EVT(MVT::i64));
- ValueVTs.push_back(EVT(MVT::i64));
+ if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
+ ValueVTs.append({MVT::i64, MVT::i64});
- if (Offsets) {
- Offsets->push_back(StartingOffset + 0);
- Offsets->push_back(StartingOffset + 8);
- }
+ if (Offsets)
+ Offsets->append({StartingOffset + 0, StartingOffset + 8});
return;
}
@@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}
-static bool IsTypePassedAsArray(const Type *Ty) {
- return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
- Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
} else {
O << "(";
if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
- !IsTypePassedAsArray(retTy)) {
+ !shouldPassAsArray(retTy)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
@@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
- } else if (IsTypePassedAsArray(retTy)) {
+ } else if (shouldPassAsArray(retTy)) {
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
@@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
first = false;
if (!Outs[OIdx].Flags.isByVal()) {
- if (IsTypePassedAsArray(Ty)) {
+ if (shouldPassAsArray(Ty)) {
Align ParamAlign =
getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
O << ".param .align " << ParamAlign.value() << " .b8 ";
@@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
bool NeedAlign; // Does argument declaration specify alignment?
- bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
+ const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
if (IsVAArg) {
if (ParamCount == FirstVAArg) {
SDValue DeclareParamOps[] = {
@@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
- if (!IsTypePassedAsArray(RetTy)) {
+ if (!shouldPassAsArray(RetTy)) {
resultsz = promoteScalarArgumentSize(resultsz);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -3344,7 +3336,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (theArgs[i]->use_empty()) {
// argument is dead
- if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
+ if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;
ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3d9d2ae372080..b800445a3b19c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
!isKernelFunction(*F);
}
-bool Isv2x16VT(EVT VT) {
- return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
-}
-
} // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 9283b398a9c14..2288241ec0178 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
-bool Isv2x16VT(EVT VT);
+inline bool Isv2x16VT(EVT VT) {
+ return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
+}
+
+inline bool shouldPassAsArray(Type *Ty) {
+ return Ty->isAggregateType() || Ty->isVectorTy() ||
+ Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
+}
namespace NVPTX {
inline std::string getValidPTXIdentifier(StringRef Name) {
diff --git a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
new file mode 100644
index 0000000000000..5b96f4978a7cb
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
+
+target triple = "nvptx64-unknown-cuda"
+
+define fp128 @identity(fp128 %x) {
+; CHECK-LABEL: identity(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
+; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd1, %rd2};
+; CHECK-NEXT: ret;
+ ret fp128 %x
+}
+
+define void @load_store(ptr %in, ptr %out) {
+; CHECK-LABEL: load_store(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
+; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
+; CHECK-NEXT: ld.u64 %rd3, [%rd1];
+; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
+; CHECK-NEXT: st.u64 [%rd4], %rd3;
+; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
+; CHECK-NEXT: ret;
+ %val = load fp128, ptr %in
+ store fp128 %val, ptr %out
+ ret void
+}
+
+define void @call(fp128 %x) {
+; CHECK-LABEL: call(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
+; CHECK-NEXT: { // callseq 0, 0
+; CHECK-NEXT: .param .align 16 .b8 param0[16];
+; CHECK-NEXT: st.param.v2.b64 [param0], {%rd1, %rd2};
+; CHECK-NEXT: call.uni
+; CHECK-NEXT: call,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: } // callseq 0
+; CHECK-NEXT: ret;
+ call void @call(fp128 %x)
+ ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index e8d7fb3815b79..422f721d934e0 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -4,12 +4,15 @@
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
-; Check that we can handle global variables of large integer type.
+; Check that we can handle global variables of large integer and fp128 type.
; (lsb) 0x0102'0304'0506...0F10 (msb)
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL33333333333333334004033333333333, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 3, 4, 64};
+
; Make sure that we do not overflow on large number of elements.
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
@large_data = global [4831838208 x i8] zeroinitializer
>From ab7d81fef13290a5e0c30a9d46ade99a40c097ee Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 16 Apr 2025 20:33:29 +0000
Subject: [PATCH 2/3] address comments
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 35 +++++++++----------
.../test/CodeGen/NVPTX/global-variable-big.ll | 4 +--
2 files changed, 18 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index e0f9a1ada3bc4..0479278d8ea0d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -259,7 +259,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
return;
O << " (";
- auto PrintScalarParam = [&](unsigned Size) {
+ auto PrintScalarRetVal = [&](unsigned Size) {
O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
};
if (shouldPassAsArray(Ty)) {
@@ -269,11 +269,11 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
<< TotalSize << "]";
} else if (Ty->isFloatingPointTy()) {
- PrintScalarParam(Ty->getPrimitiveSizeInBits());
+ PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
} else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
- PrintScalarParam(ITy->getBitWidth());
+ PrintScalarRetVal(ITy->getBitWidth());
} else if (isa<PointerType>(Ty)) {
- PrintScalarParam(TLI->getPointerTy(DL).getSizeInBits());
+ PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
} else
llvm_unreachable("Unknown return type");
O << ") ";
@@ -1674,24 +1674,24 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
AggBuffer *aggBuffer) {
const DataLayout &DL = getDataLayout();
- auto BufferConstant = [&](APInt Val) {
- for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
+ auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
+ for (unsigned _ : llvm::seq(Val.getBitWidth() / 8)) {
uint8_t Byte = Val.getLoBits(8).getZExtValue();
- aggBuffer->addBytes(&Byte, 1, 1);
+ Buffer->addBytes(&Byte, 1, 1);
Val.lshrInPlace(8);
}
};
// Integers of arbitrary width
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
- BufferConstant(CI->getValue());
+ ExtendBuffer(CI->getValue(), aggBuffer);
return;
}
// f128
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
if (CFP->getType()->isFP128Ty()) {
- BufferConstant(CFP->getValueAPF().bitcastToAPInt());
+ ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
return;
}
}
@@ -1705,7 +1705,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
if (CDS->getNumElements())
- for (unsigned i = 0; i < CDS->getNumElements(); ++i)
+ for (unsigned i : llvm::seq(CDS->getNumElements()))
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
aggBuffer);
return;
@@ -1714,15 +1714,12 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
if (isa<ConstantStruct>(CPV)) {
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
- for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
- int Bytes;
- if (i == (e - 1))
- Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
- DL.getTypeAllocSize(ST) -
- DL.getStructLayout(ST)->getElementOffset(i);
- else
- Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
- DL.getStructLayout(ST)->getElementOffset(i);
+ for (unsigned i : llvm::seq(CPV->getNumOperands())) {
+ int EndOffset = (i + 1 == CPV->getNumOperands())
+ ? DL.getStructLayout(ST)->getElementOffset(0) +
+ DL.getTypeAllocSize(ST)
+ : DL.getStructLayout(ST)->getElementOffset(i + 1);
+ int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(i);
bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
}
}
diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index 422f721d934e0..09f556e72a2bd 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -10,8 +10,8 @@ target triple = "nvptx64-nvidia-cuda"
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
-@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL33333333333333334004033333333333, align 16
-; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 3, 4, 64};
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL0807060504030201100F0E0D0C0B0A09, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
; Make sure that we do not overflow on large number of elements.
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
>From a487d6b24065372d7883cdb62d076fbf1100b918 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 16 Apr 2025 23:23:53 +0000
Subject: [PATCH 3/3] address comments
---
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 23 +++++++++--------------
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h | 6 ++++++
2 files changed, 15 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0479278d8ea0d..2f4b109e8e9e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1675,11 +1675,8 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
const DataLayout &DL = getDataLayout();
auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
- for (unsigned _ : llvm::seq(Val.getBitWidth() / 8)) {
- uint8_t Byte = Val.getLoBits(8).getZExtValue();
- Buffer->addBytes(&Byte, 1, 1);
- Val.lshrInPlace(8);
- }
+ for (unsigned I : llvm::seq(Val.getBitWidth() / 8))
+ Buffer->addByte(Val.extractBitsAsZExtValue(8, I * 8));
};
// Integers of arbitrary width
@@ -1704,23 +1701,21 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
}
if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
- if (CDS->getNumElements())
- for (unsigned i : llvm::seq(CDS->getNumElements()))
- bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
- aggBuffer);
+ for (unsigned I : llvm::seq(CDS->getNumElements()))
+ bufferLEByte(cast<Constant>(CDS->getElementAsConstant(I)), 0, aggBuffer);
return;
}
if (isa<ConstantStruct>(CPV)) {
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
- for (unsigned i : llvm::seq(CPV->getNumOperands())) {
- int EndOffset = (i + 1 == CPV->getNumOperands())
+ for (unsigned I : llvm::seq(CPV->getNumOperands())) {
+ int EndOffset = (I + 1 == CPV->getNumOperands())
? DL.getStructLayout(ST)->getElementOffset(0) +
DL.getTypeAllocSize(ST)
- : DL.getStructLayout(ST)->getElementOffset(i + 1);
- int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(i);
- bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
+ : DL.getStructLayout(ST)->getElementOffset(I + 1);
+ int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(I);
+ bufferLEByte(cast<Constant>(CPV->getOperand(I)), Bytes, aggBuffer);
}
}
return;
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index 74daaa2fb7134..ce19f5554bfbe 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -125,6 +125,12 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
return curpos;
}
+ void addByte(uint8_t Byte) {
+ assert((curpos + 1) <= size);
+ buffer[curpos] = Byte;
+ curpos++;
+ }
+
unsigned addZeros(int Num) {
assert((curpos + Num) <= size);
for (int i = 0; i < Num; ++i) {
More information about the llvm-commits
mailing list