[llvm] 3001387 - [NVPTX] Basic support for fp128 as a storage type (#136006)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 17 13:19:51 PDT 2025
Author: Alex MacLean
Date: 2025-04-17T13:19:48-07:00
New Revision: 30013872190ca05eb00333adb989c9f74b1cf3ac
URL: https://github.com/llvm/llvm-project/commit/30013872190ca05eb00333adb989c9f74b1cf3ac
DIFF: https://github.com/llvm/llvm-project/commit/30013872190ca05eb00333adb989c9f74b1cf3ac.diff
LOG: [NVPTX] Basic support for fp128 as a storage type (#136006)
While fp128 operations are not natively supported in hardware, nvcc supports
emulating them. This change adds basic support for fp128 as a storage type,
allowing IR that contains fp128 values to be lowered.
Fixes: https://github.com/llvm/llvm-project/issues/95471
Added:
llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
Modified:
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
llvm/lib/Target/NVPTX/NVPTXUtilities.h
llvm/test/CodeGen/NVPTX/global-variable-big.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 65cfeadc21a3b..2f4b109e8e9e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
return MCOperand::createExpr(Expr);
}
-static bool ShouldPassAsArray(Type *Ty) {
- return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
- Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
return;
O << " (";
- if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
- !ShouldPassAsArray(Ty)) {
- unsigned size = 0;
- if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
- size = ITy->getBitWidth();
- } else {
- assert(Ty->isFloatingPointTy() && "Floating point type expected here");
- size = Ty->getPrimitiveSizeInBits();
- }
- size = promoteScalarArgumentSize(size);
- O << ".param .b" << size << " func_retval0";
- } else if (isa<PointerType>(Ty)) {
- O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
- << " func_retval0";
- } else if (ShouldPassAsArray(Ty)) {
- unsigned totalsz = DL.getTypeAllocSize(Ty);
- Align RetAlignment = TLI->getFunctionArgumentAlignment(
+ auto PrintScalarRetVal = [&](unsigned Size) {
+ O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
+ };
+ if (shouldPassAsArray(Ty)) {
+ const unsigned TotalSize = DL.getTypeAllocSize(Ty);
+ const Align RetAlignment = TLI->getFunctionArgumentAlignment(
F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
- << totalsz << "]";
+ << TotalSize << "]";
+ } else if (Ty->isFloatingPointTy()) {
+ PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
+ } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+ PrintScalarRetVal(ITy->getBitWidth());
+ } else if (isa<PointerType>(Ty)) {
+ PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
} else
llvm_unreachable("Unknown return type");
O << ") ";
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
- if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
- (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
+ if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
+ ETy->getScalarSizeInBits() <= 64)) {
O << " .";
// Special case: ABI requires that we use .u8 for predicates
if (ETy->isIntegerTy(1))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// and vectors are lowered into arrays of bytes.
switch (ETy->getTypeID()) {
case Type::IntegerTyID: // Integers larger than 64 bits
+ case Type::FP128TyID:
case Type::StructTyID:
case Type::ArrayTyID:
case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
- // Special case for i128
- if (ETy->isIntegerTy(128)) {
+ // Special case for i128/fp128
+ if (ETy->getScalarSizeInBits() == 128) {
O << " .b8 ";
getSymbol(GVar)->print(O, MAI);
O << "[16]";
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
continue;
}
- if (ShouldPassAsArray(Ty)) {
+ if (shouldPassAsArray(Ty)) {
// Just print .param .align <a> .b8 .param[size];
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
@@ -1682,48 +1673,49 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
AggBuffer *aggBuffer) {
const DataLayout &DL = getDataLayout();
- int Bytes;
+
+ auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
+ for (unsigned I : llvm::seq(Val.getBitWidth() / 8))
+ Buffer->addByte(Val.extractBitsAsZExtValue(8, I * 8));
+ };
// Integers of arbitrary width
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
- APInt Val = CI->getValue();
- for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
- uint8_t Byte = Val.getLoBits(8).getZExtValue();
- aggBuffer->addBytes(&Byte, 1, 1);
- Val.lshrInPlace(8);
- }
+ ExtendBuffer(CI->getValue(), aggBuffer);
return;
}
+ // f128
+ if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
+ if (CFP->getType()->isFP128Ty()) {
+ ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
+ return;
+ }
+ }
+
// Old constants
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
- if (CPV->getNumOperands())
- for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
- bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
+ for (const auto &Op : CPV->operands())
+ bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
return;
}
- if (const ConstantDataSequential *CDS =
- dyn_cast<ConstantDataSequential>(CPV)) {
- if (CDS->getNumElements())
- for (unsigned i = 0; i < CDS->getNumElements(); ++i)
- bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
- aggBuffer);
+ if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
+ for (unsigned I : llvm::seq(CDS->getNumElements()))
+ bufferLEByte(cast<Constant>(CDS->getElementAsConstant(I)), 0, aggBuffer);
return;
}
if (isa<ConstantStruct>(CPV)) {
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
- for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
- if (i == (e - 1))
- Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
- DL.getTypeAllocSize(ST) -
- DL.getStructLayout(ST)->getElementOffset(i);
- else
- Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
- DL.getStructLayout(ST)->getElementOffset(i);
- bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
+ for (unsigned I : llvm::seq(CPV->getNumOperands())) {
+ int EndOffset = (I + 1 == CPV->getNumOperands())
+ ? DL.getStructLayout(ST)->getElementOffset(0) +
+ DL.getTypeAllocSize(ST)
+ : DL.getStructLayout(ST)->getElementOffset(I + 1);
+ int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(I);
+ bufferLEByte(cast<Constant>(CPV->getOperand(I)), Bytes, aggBuffer);
}
}
return;
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index 74daaa2fb7134..9ed7e650e7b0c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -111,27 +111,22 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
// Copy Num bytes from Ptr.
// if Bytes > Num, zero fill up to Bytes.
- unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) {
- assert((curpos + Num) <= size);
- assert((curpos + Bytes) <= size);
- for (int i = 0; i < Num; ++i) {
- buffer[curpos] = Ptr[i];
- curpos++;
- }
- for (int i = Num; i < Bytes; ++i) {
- buffer[curpos] = 0;
- curpos++;
- }
- return curpos;
+ void addBytes(const unsigned char *Ptr, unsigned Num, unsigned Bytes) {
+ for (unsigned I : llvm::seq(Num))
+ addByte(Ptr[I]);
+ if (Bytes > Num)
+ addZeros(Bytes - Num);
}
- unsigned addZeros(int Num) {
- assert((curpos + Num) <= size);
- for (int i = 0; i < Num; ++i) {
- buffer[curpos] = 0;
- curpos++;
- }
- return curpos;
+ void addByte(uint8_t Byte) {
+ assert(curpos < size);
+ buffer[curpos] = Byte;
+ curpos++;
+ }
+
+ void addZeros(unsigned Num) {
+ for (unsigned _ : llvm::seq(Num))
+ addByte(0);
}
void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index abe4c27009698..277a34173e7b8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
SmallVector<uint64_t, 16> TempOffsets;
// Special case for i128 - decompose to (i64, i64)
- if (Ty->isIntegerTy(128)) {
- ValueVTs.push_back(EVT(MVT::i64));
- ValueVTs.push_back(EVT(MVT::i64));
+ if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
+ ValueVTs.append({MVT::i64, MVT::i64});
- if (Offsets) {
- Offsets->push_back(StartingOffset + 0);
- Offsets->push_back(StartingOffset + 8);
- }
+ if (Offsets)
+ Offsets->append({StartingOffset + 0, StartingOffset + 8});
return;
}
@@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}
-static bool IsTypePassedAsArray(const Type *Ty) {
- return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
- Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
} else {
O << "(";
if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
- !IsTypePassedAsArray(retTy)) {
+ !shouldPassAsArray(retTy)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
@@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
- } else if (IsTypePassedAsArray(retTy)) {
+ } else if (shouldPassAsArray(retTy)) {
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
@@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
first = false;
if (!Outs[OIdx].Flags.isByVal()) {
- if (IsTypePassedAsArray(Ty)) {
+ if (shouldPassAsArray(Ty)) {
Align ParamAlign =
getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
O << ".param .align " << ParamAlign.value() << " .b8 ";
@@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
bool NeedAlign; // Does argument declaration specify alignment?
- bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
+ const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
if (IsVAArg) {
if (ParamCount == FirstVAArg) {
SDValue DeclareParamOps[] = {
@@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
- if (!IsTypePassedAsArray(RetTy)) {
+ if (!shouldPassAsArray(RetTy)) {
resultsz = promoteScalarArgumentSize(resultsz);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -3362,7 +3354,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (theArgs[i]->use_empty()) {
// argument is dead
- if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
+ if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;
ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3d9d2ae372080..b800445a3b19c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
!isKernelFunction(*F);
}
-bool Isv2x16VT(EVT VT) {
- return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
-}
-
} // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 9283b398a9c14..2288241ec0178 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
-bool Isv2x16VT(EVT VT);
+inline bool Isv2x16VT(EVT VT) {
+ return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
+}
+
+inline bool shouldPassAsArray(Type *Ty) {
+ return Ty->isAggregateType() || Ty->isVectorTy() ||
+ Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
+}
namespace NVPTX {
inline std::string getValidPTXIdentifier(StringRef Name) {
diff --git a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
new file mode 100644
index 0000000000000..5b96f4978a7cb
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
+
+target triple = "nvptx64-unknown-cuda"
+
+define fp128 @identity(fp128 %x) {
+; CHECK-LABEL: identity(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
+; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd1, %rd2};
+; CHECK-NEXT: ret;
+ ret fp128 %x
+}
+
+define void @load_store(ptr %in, ptr %out) {
+; CHECK-LABEL: load_store(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
+; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
+; CHECK-NEXT: ld.u64 %rd3, [%rd1];
+; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
+; CHECK-NEXT: st.u64 [%rd4], %rd3;
+; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
+; CHECK-NEXT: ret;
+ %val = load fp128, ptr %in
+ store fp128 %val, ptr %out
+ ret void
+}
+
+define void @call(fp128 %x) {
+; CHECK-LABEL: call(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
+; CHECK-NEXT: { // callseq 0, 0
+; CHECK-NEXT: .param .align 16 .b8 param0[16];
+; CHECK-NEXT: st.param.v2.b64 [param0], {%rd1, %rd2};
+; CHECK-NEXT: call.uni
+; CHECK-NEXT: call,
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: );
+; CHECK-NEXT: } // callseq 0
+; CHECK-NEXT: ret;
+ call void @call(fp128 %x)
+ ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index e8d7fb3815b79..09f556e72a2bd 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -4,12 +4,15 @@
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
-; Check that we can handle global variables of large integer type.
+; Check that we can handle global variables of large integer and fp128 type.
; (lsb) 0x0102'0304'0506...0F10 (msb)
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL0807060504030201100F0E0D0C0B0A09, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+
; Make sure that we do not overflow on large number of elements.
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
@large_data = global [4831838208 x i8] zeroinitializer
More information about the llvm-commits
mailing list