[llvm] 3001387 - [NVPTX] Basic support for fp128 as a storage type (#136006)

via llvm-commits <llvm-commits@lists.llvm.org>
Thu Apr 17 13:19:51 PDT 2025


Author: Alex MacLean
Date: 2025-04-17T13:19:48-07:00
New Revision: 30013872190ca05eb00333adb989c9f74b1cf3ac

URL: https://github.com/llvm/llvm-project/commit/30013872190ca05eb00333adb989c9f74b1cf3ac
DIFF: https://github.com/llvm/llvm-project/commit/30013872190ca05eb00333adb989c9f74b1cf3ac.diff

LOG: [NVPTX] Basic support for fp128 as a storage type (#136006)

While fp128 operations are not natively supported in hardware, nvcc supports
emulating them. This change adds basic support for fp128 as a storage type,
allowing IR containing these types to be lowered.

Fixes: https://github.com/llvm/llvm-project/issues/95471

Added: 
    llvm/test/CodeGen/NVPTX/fp128-storage-type.ll

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
    llvm/lib/Target/NVPTX/NVPTXUtilities.h
    llvm/test/CodeGen/NVPTX/global-variable-big.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 65cfeadc21a3b..2f4b109e8e9e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
   return MCOperand::createExpr(Expr);
 }
 
-static bool ShouldPassAsArray(Type *Ty) {
-  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
-         Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
     return;
   O << " (";
 
-  if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
-      !ShouldPassAsArray(Ty)) {
-    unsigned size = 0;
-    if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
-      size = ITy->getBitWidth();
-    } else {
-      assert(Ty->isFloatingPointTy() && "Floating point type expected here");
-      size = Ty->getPrimitiveSizeInBits();
-    }
-    size = promoteScalarArgumentSize(size);
-    O << ".param .b" << size << " func_retval0";
-  } else if (isa<PointerType>(Ty)) {
-    O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
-      << " func_retval0";
-  } else if (ShouldPassAsArray(Ty)) {
-    unsigned totalsz = DL.getTypeAllocSize(Ty);
-    Align RetAlignment = TLI->getFunctionArgumentAlignment(
+  auto PrintScalarRetVal = [&](unsigned Size) {
+    O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
+  };
+  if (shouldPassAsArray(Ty)) {
+    const unsigned TotalSize = DL.getTypeAllocSize(Ty);
+    const Align RetAlignment = TLI->getFunctionArgumentAlignment(
         F, Ty, AttributeList::ReturnIndex, DL);
     O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
-      << totalsz << "]";
+      << TotalSize << "]";
+  } else if (Ty->isFloatingPointTy()) {
+    PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
+  } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+    PrintScalarRetVal(ITy->getBitWidth());
+  } else if (isa<PointerType>(Ty)) {
+    PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
   } else
     llvm_unreachable("Unknown return type");
   O << ") ";
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   O << " .align "
     << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
-  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
-      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
+  if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
+                             ETy->getScalarSizeInBits() <= 64)) {
     O << " .";
     // Special case: ABI requires that we use .u8 for predicates
     if (ETy->isIntegerTy(1))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     // and vectors are lowered into arrays of bytes.
     switch (ETy->getTypeID()) {
     case Type::IntegerTyID: // Integers larger than 64 bits
+    case Type::FP128TyID:
     case Type::StructTyID:
     case Type::ArrayTyID:
     case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   O << " .align "
     << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
-  // Special case for i128
-  if (ETy->isIntegerTy(128)) {
+  // Special case for i128/fp128
+  if (ETy->getScalarSizeInBits() == 128) {
     O << " .b8 ";
     getSymbol(GVar)->print(O, MAI);
     O << "[16]";
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       continue;
     }
 
-    if (ShouldPassAsArray(Ty)) {
+    if (shouldPassAsArray(Ty)) {
       // Just print .param .align <a> .b8 .param[size];
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
@@ -1682,48 +1673,49 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                               AggBuffer *aggBuffer) {
   const DataLayout &DL = getDataLayout();
-  int Bytes;
+
+  auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
+    for (unsigned I : llvm::seq(Val.getBitWidth() / 8))
+      Buffer->addByte(Val.extractBitsAsZExtValue(8, I * 8));
+  };
 
   // Integers of arbitrary width
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
-    APInt Val = CI->getValue();
-    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
-      uint8_t Byte = Val.getLoBits(8).getZExtValue();
-      aggBuffer->addBytes(&Byte, 1, 1);
-      Val.lshrInPlace(8);
-    }
+    ExtendBuffer(CI->getValue(), aggBuffer);
     return;
   }
 
+  // f128
+  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
+    if (CFP->getType()->isFP128Ty()) {
+      ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
+      return;
+    }
+  }
+
   // Old constants
   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
-    if (CPV->getNumOperands())
-      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
-        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
+    for (const auto &Op : CPV->operands())
+      bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
     return;
   }
 
-  if (const ConstantDataSequential *CDS =
-          dyn_cast<ConstantDataSequential>(CPV)) {
-    if (CDS->getNumElements())
-      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
-        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
-                     aggBuffer);
+  if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
+    for (unsigned I : llvm::seq(CDS->getNumElements()))
+      bufferLEByte(cast<Constant>(CDS->getElementAsConstant(I)), 0, aggBuffer);
     return;
   }
 
   if (isa<ConstantStruct>(CPV)) {
     if (CPV->getNumOperands()) {
       StructType *ST = cast<StructType>(CPV->getType());
-      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
-        if (i == (e - 1))
-          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
-                  DL.getTypeAllocSize(ST) -
-                  DL.getStructLayout(ST)->getElementOffset(i);
-        else
-          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
-                  DL.getStructLayout(ST)->getElementOffset(i);
-        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
+      for (unsigned I : llvm::seq(CPV->getNumOperands())) {
+        int EndOffset = (I + 1 == CPV->getNumOperands())
+                            ? DL.getStructLayout(ST)->getElementOffset(0) +
+                                  DL.getTypeAllocSize(ST)
+                            : DL.getStructLayout(ST)->getElementOffset(I + 1);
+        int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(I);
+        bufferLEByte(cast<Constant>(CPV->getOperand(I)), Bytes, aggBuffer);
       }
     }
     return;

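For readers unfamiliar with the APInt idiom used by the new ExtendBuffer lambda above, the following standalone sketch (an illustration only, not part of the patch) runs the same loop over a 128-bit pattern whose bytes, least-significant first, are 1 through 16 (the byte sequence the gv_fp128 check in global-variable-big.ll below expects). It shows that fp128 initializers are serialized one byte at a time, low byte first, just like wide integers.

  // Standalone sketch, not part of the patch: mirrors the ExtendBuffer loop to
  // show the little-endian byte order used when serializing fp128 constants.
  #include "llvm/ADT/APFloat.h"
  #include "llvm/ADT/APInt.h"
  #include "llvm/Support/raw_ostream.h"

  using namespace llvm;

  int main() {
    // A 128-bit pattern whose bytes, LSB first, are 1, 2, ..., 16.
    APInt Bits(128, "100F0E0D0C0B0A090807060504030201", 16);
    // Round-trip through APFloat the way bufferAggregateConstant does for
    // ConstantFP: getValueAPF().bitcastToAPInt() yields the raw bits again.
    APFloat F(APFloat::IEEEquad(), Bits);
    APInt Val = F.bitcastToAPInt();

    // Same shape as the ExtendBuffer lambda: one byte at a time, LSB first.
    for (unsigned I = 0; I != Val.getBitWidth() / 8; ++I)
      outs() << Val.extractBitsAsZExtValue(8, I * 8) << ' ';
    outs() << '\n'; // prints: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
  }
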
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index 74daaa2fb7134..9ed7e650e7b0c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -111,27 +111,22 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
 
     // Copy Num bytes from Ptr.
     // if Bytes > Num, zero fill up to Bytes.
-    unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) {
-      assert((curpos + Num) <= size);
-      assert((curpos + Bytes) <= size);
-      for (int i = 0; i < Num; ++i) {
-        buffer[curpos] = Ptr[i];
-        curpos++;
-      }
-      for (int i = Num; i < Bytes; ++i) {
-        buffer[curpos] = 0;
-        curpos++;
-      }
-      return curpos;
+    void addBytes(const unsigned char *Ptr, unsigned Num, unsigned Bytes) {
+      for (unsigned I : llvm::seq(Num))
+        addByte(Ptr[I]);
+      if (Bytes > Num)
+        addZeros(Bytes - Num);
     }
 
-    unsigned addZeros(int Num) {
-      assert((curpos + Num) <= size);
-      for (int i = 0; i < Num; ++i) {
-        buffer[curpos] = 0;
-        curpos++;
-      }
-      return curpos;
+    void addByte(uint8_t Byte) {
+      assert(curpos < size);
+      buffer[curpos] = Byte;
+      curpos++;
+    }
+
+    void addZeros(unsigned Num) {
+      for (unsigned _ : llvm::seq(Num))
+        addByte(0);
     }
 
     void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) {

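As a usage note on the reworked AggBuffer helpers, here is a simplified stand-in (a sketch only; the real class also tracks symbols) that demonstrates the contract addBytes preserves: copy Num bytes, then zero-fill the remainder up to Bytes, now expressed through the new addByte/addZeros primitives.

  // Simplified model of AggBuffer, illustrating the zero-fill behavior of
  // addBytes when Bytes > Num.
  #include <cassert>
  #include <cstdint>
  #include <vector>

  struct MiniAggBuffer {
    std::vector<uint8_t> buffer;
    size_t curpos = 0;

    explicit MiniAggBuffer(size_t Size) : buffer(Size) {}

    void addByte(uint8_t Byte) {
      assert(curpos < buffer.size() && "buffer overflow");
      buffer[curpos++] = Byte;
    }

    void addZeros(unsigned Num) {
      for (unsigned I = 0; I != Num; ++I)
        addByte(0);
    }

    // Copy Num bytes from Ptr; if Bytes > Num, zero-fill up to Bytes.
    void addBytes(const uint8_t *Ptr, unsigned Num, unsigned Bytes) {
      for (unsigned I = 0; I != Num; ++I)
        addByte(Ptr[I]);
      if (Bytes > Num)
        addZeros(Bytes - Num);
    }
  };

  int main() {
    MiniAggBuffer Buf(4);
    const uint8_t Data[] = {0xAA, 0xBB};
    Buf.addBytes(Data, 2, 4);
    assert(Buf.buffer == (std::vector<uint8_t>{0xAA, 0xBB, 0x00, 0x00}));
  }
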
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index abe4c27009698..277a34173e7b8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   SmallVector<uint64_t, 16> TempOffsets;
 
   // Special case for i128 - decompose to (i64, i64)
-  if (Ty->isIntegerTy(128)) {
-    ValueVTs.push_back(EVT(MVT::i64));
-    ValueVTs.push_back(EVT(MVT::i64));
+  if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
+    ValueVTs.append({MVT::i64, MVT::i64});
 
-    if (Offsets) {
-      Offsets->push_back(StartingOffset + 0);
-      Offsets->push_back(StartingOffset + 8);
-    }
+    if (Offsets)
+      Offsets->append({StartingOffset + 0, StartingOffset + 8});
 
     return;
   }
@@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
 }
 
-static bool IsTypePassedAsArray(const Type *Ty) {
-  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
-         Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
   } else {
     O << "(";
     if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
-        !IsTypePassedAsArray(retTy)) {
+        !shouldPassAsArray(retTy)) {
       unsigned size = 0;
       if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
         size = ITy->getBitWidth();
@@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
-    } else if (IsTypePassedAsArray(retTy)) {
+    } else if (shouldPassAsArray(retTy)) {
       O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
         << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
     } else {
@@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
     first = false;
 
     if (!Outs[OIdx].Flags.isByVal()) {
-      if (IsTypePassedAsArray(Ty)) {
+      if (shouldPassAsArray(Ty)) {
         Align ParamAlign =
             getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
         O << ".param .align " << ParamAlign.value() << " .b8 ";
@@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
 
     bool NeedAlign; // Does argument declaration specify alignment?
-    bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
+    const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
     if (IsVAArg) {
       if (ParamCount == FirstVAArg) {
         SDValue DeclareParamOps[] = {
@@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     //  .param .align N .b8 retval0[<size-in-bytes>], or
     //  .param .b<size-in-bits> retval0
     unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
-    if (!IsTypePassedAsArray(RetTy)) {
+    if (!shouldPassAsArray(RetTy)) {
       resultsz = promoteScalarArgumentSize(resultsz);
       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -3362,7 +3354,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
     if (theArgs[i]->use_empty()) {
       // argument is dead
-      if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
+      if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
         SmallVector<EVT, 16> vtparts;
 
         ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);

diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3d9d2ae372080..b800445a3b19c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
          !isKernelFunction(*F);
 }
 
-bool Isv2x16VT(EVT VT) {
-  return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
-}
-
 } // namespace llvm

diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 9283b398a9c14..2288241ec0178 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
 
 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
 
-bool Isv2x16VT(EVT VT);
+inline bool Isv2x16VT(EVT VT) {
+  return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
+}
+
+inline bool shouldPassAsArray(Type *Ty) {
+  return Ty->isAggregateType() || Ty->isVectorTy() ||
+         Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
+}
 
 namespace NVPTX {
 inline std::string getValidPTXIdentifier(StringRef Name) {

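For a quick sense of what the widened predicate now covers, here is a standalone sketch (illustration only; shouldPassAsArrayCopy duplicates the body of the new inline shouldPassAsArray so the snippet builds without the target-internal header). Both fp128 and i128 satisfy the getScalarSizeInBits() == 128 check and are passed as byte arrays, while narrower scalars keep the scalar .param form.

  // Standalone sketch: how a copy of the new shouldPassAsArray predicate
  // classifies a few scalar types.
  #include "llvm/IR/LLVMContext.h"
  #include "llvm/IR/Type.h"
  #include "llvm/Support/raw_ostream.h"

  using namespace llvm;

  // Copy of the predicate added to NVPTXUtilities.h above.
  static bool shouldPassAsArrayCopy(Type *Ty) {
    return Ty->isAggregateType() || Ty->isVectorTy() ||
           Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
  }

  int main() {
    LLVMContext Ctx;
    // fp128 and i128 now take the same path; float and i64 remain scalars.
    for (Type *Ty : {Type::getFP128Ty(Ctx), Type::getInt128Ty(Ctx),
                     Type::getFloatTy(Ctx), Type::getInt64Ty(Ctx)}) {
      Ty->print(outs());
      outs() << " -> " << (shouldPassAsArrayCopy(Ty) ? "array" : "scalar") << '\n';
    }
  }
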
diff --git a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
new file mode 100644
index 0000000000000..5b96f4978a7cb
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
+
+target triple = "nvptx64-unknown-cuda"
+
+define fp128 @identity(fp128 %x) {
+; CHECK-LABEL: identity(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd1, %rd2};
+; CHECK-NEXT:    ret;
+  ret fp128 %x
+}
+
+define void @load_store(ptr %in, ptr %out) {
+; CHECK-LABEL: load_store(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [load_store_param_0];
+; CHECK-NEXT:    ld.u64 %rd2, [%rd1+8];
+; CHECK-NEXT:    ld.u64 %rd3, [%rd1];
+; CHECK-NEXT:    ld.param.u64 %rd4, [load_store_param_1];
+; CHECK-NEXT:    st.u64 [%rd4], %rd3;
+; CHECK-NEXT:    st.u64 [%rd4+8], %rd2;
+; CHECK-NEXT:    ret;
+  %val = load fp128, ptr %in
+  store fp128 %val, ptr %out
+  ret void
+}
+
+define void @call(fp128 %x) {
+; CHECK-LABEL: call(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .align 16 .b8 param0[16];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd1, %rd2};
+; CHECK-NEXT:    call.uni
+; CHECK-NEXT:    call,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    ret;
+  call void @call(fp128 %x)
+  ret void
+}

diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index e8d7fb3815b79..09f556e72a2bd 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -4,12 +4,15 @@
 target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
 target triple = "nvptx64-nvidia-cuda"
 
-; Check that we can handle global variables of large integer type.
+; Check that we can handle global variables of large integer and fp128 type.
 
 ; (lsb) 0x0102'0304'0506...0F10 (msb)
 @gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
 ; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
 
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL0807060504030201100F0E0D0C0B0A09, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+
 ; Make sure that we do not overflow on large number of elements.
 ; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
 @large_data = global [4831838208 x i8] zeroinitializer

