[llvm] [NVPTX] Basic support for fp128 as a storage type (PR #136006)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 16 16:24:07 PDT 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/136006

>From db1baa6daffd6d7b466325c14dc87715b4ca4bfd Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 16 Apr 2025 15:50:52 +0000
Subject: [PATCH 1/3] [NVPTX] Basic support for fp128 as a storage type

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     | 74 +++++++++----------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 28 +++----
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp      |  4 -
 llvm/lib/Target/NVPTX/NVPTXUtilities.h        |  9 ++-
 llvm/test/CodeGen/NVPTX/fp128-storage-type.ll | 56 ++++++++++++++
 .../test/CodeGen/NVPTX/global-variable-big.ll |  5 +-
 6 files changed, 115 insertions(+), 61 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/fp128-storage-type.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 65cfeadc21a3b..e0f9a1ada3bc4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
   return MCOperand::createExpr(Expr);
 }
 
-static bool ShouldPassAsArray(Type *Ty) {
-  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
-         Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
     return;
   O << " (";
 
-  if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
-      !ShouldPassAsArray(Ty)) {
-    unsigned size = 0;
-    if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
-      size = ITy->getBitWidth();
-    } else {
-      assert(Ty->isFloatingPointTy() && "Floating point type expected here");
-      size = Ty->getPrimitiveSizeInBits();
-    }
-    size = promoteScalarArgumentSize(size);
-    O << ".param .b" << size << " func_retval0";
-  } else if (isa<PointerType>(Ty)) {
-    O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
-      << " func_retval0";
-  } else if (ShouldPassAsArray(Ty)) {
-    unsigned totalsz = DL.getTypeAllocSize(Ty);
-    Align RetAlignment = TLI->getFunctionArgumentAlignment(
+  auto PrintScalarParam = [&](unsigned Size) {
+    O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
+  };
+  if (shouldPassAsArray(Ty)) {
+    const unsigned TotalSize = DL.getTypeAllocSize(Ty);
+    const Align RetAlignment = TLI->getFunctionArgumentAlignment(
         F, Ty, AttributeList::ReturnIndex, DL);
     O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
-      << totalsz << "]";
+      << TotalSize << "]";
+  } else if (Ty->isFloatingPointTy()) {
+    PrintScalarParam(Ty->getPrimitiveSizeInBits());
+  } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+    PrintScalarParam(ITy->getBitWidth());
+  } else if (isa<PointerType>(Ty)) {
+    PrintScalarParam(TLI->getPointerTy(DL).getSizeInBits());
   } else
     llvm_unreachable("Unknown return type");
   O << ") ";
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   O << " .align "
     << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
-  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
-      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
+  if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
+                             ETy->getScalarSizeInBits() <= 64)) {
     O << " .";
     // Special case: ABI requires that we use .u8 for predicates
     if (ETy->isIntegerTy(1))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     // and vectors are lowered into arrays of bytes.
     switch (ETy->getTypeID()) {
     case Type::IntegerTyID: // Integers larger than 64 bits
+    case Type::FP128TyID:
     case Type::StructTyID:
     case Type::ArrayTyID:
     case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   O << " .align "
     << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
-  // Special case for i128
-  if (ETy->isIntegerTy(128)) {
+  // Special case for i128/fp128
+  if (ETy->getScalarSizeInBits() == 128) {
     O << " .b8 ";
     getSymbol(GVar)->print(O, MAI);
     O << "[16]";
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       continue;
     }
 
-    if (ShouldPassAsArray(Ty)) {
+    if (shouldPassAsArray(Ty)) {
       // Just print .param .align <a> .b8 .param[size];
       // <a>  = optimal alignment for the element type; always multiple of
       //        PAL.getParamAlignment
@@ -1682,29 +1673,37 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                               AggBuffer *aggBuffer) {
   const DataLayout &DL = getDataLayout();
-  int Bytes;
 
-  // Integers of arbitrary width
-  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
-    APInt Val = CI->getValue();
+  auto BufferConstant = [&](APInt Val) {
     for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
       uint8_t Byte = Val.getLoBits(8).getZExtValue();
       aggBuffer->addBytes(&Byte, 1, 1);
       Val.lshrInPlace(8);
     }
+  };
+
+  // Integers of arbitrary width
+  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
+    BufferConstant(CI->getValue());
     return;
   }
 
+  // f128
+  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
+    if (CFP->getType()->isFP128Ty()) {
+      BufferConstant(CFP->getValueAPF().bitcastToAPInt());
+      return;
+    }
+  }
+
   // Old constants
   if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
-    if (CPV->getNumOperands())
-      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
-        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
+    for (const auto &Op : CPV->operands())
+      bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
     return;
   }
 
-  if (const ConstantDataSequential *CDS =
-          dyn_cast<ConstantDataSequential>(CPV)) {
+  if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
     if (CDS->getNumElements())
       for (unsigned i = 0; i < CDS->getNumElements(); ++i)
         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
@@ -1716,6 +1715,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
     if (CPV->getNumOperands()) {
       StructType *ST = cast<StructType>(CPV->getType());
       for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
+        int Bytes;
         if (i == (e - 1))
           Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
                   DL.getTypeAllocSize(ST) -
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..8d26785b898f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
   SmallVector<uint64_t, 16> TempOffsets;
 
   // Special case for i128 - decompose to (i64, i64)
-  if (Ty->isIntegerTy(128)) {
-    ValueVTs.push_back(EVT(MVT::i64));
-    ValueVTs.push_back(EVT(MVT::i64));
+  if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
+    ValueVTs.append({MVT::i64, MVT::i64});
 
-    if (Offsets) {
-      Offsets->push_back(StartingOffset + 0);
-      Offsets->push_back(StartingOffset + 8);
-    }
+    if (Offsets)
+      Offsets->append({StartingOffset + 0, StartingOffset + 8});
 
     return;
   }
@@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
 }
 
-static bool IsTypePassedAsArray(const Type *Ty) {
-  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
-         Ty->isHalfTy() || Ty->isBFloatTy();
-}
-
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
   } else {
     O << "(";
     if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
-        !IsTypePassedAsArray(retTy)) {
+        !shouldPassAsArray(retTy)) {
       unsigned size = 0;
       if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
         size = ITy->getBitWidth();
@@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
-    } else if (IsTypePassedAsArray(retTy)) {
+    } else if (shouldPassAsArray(retTy)) {
       O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
         << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
     } else {
@@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
     first = false;
 
     if (!Outs[OIdx].Flags.isByVal()) {
-      if (IsTypePassedAsArray(Ty)) {
+      if (shouldPassAsArray(Ty)) {
         Align ParamAlign =
             getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
         O << ".param .align " << ParamAlign.value() << " .b8 ";
@@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
 
     bool NeedAlign; // Does argument declaration specify alignment?
-    bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
+    const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
     if (IsVAArg) {
       if (ParamCount == FirstVAArg) {
         SDValue DeclareParamOps[] = {
@@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     //  .param .align N .b8 retval0[<size-in-bytes>], or
     //  .param .b<size-in-bits> retval0
     unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
-    if (!IsTypePassedAsArray(RetTy)) {
+    if (!shouldPassAsArray(RetTy)) {
       resultsz = promoteScalarArgumentSize(resultsz);
       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -3344,7 +3336,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
     if (theArgs[i]->use_empty()) {
       // argument is dead
-      if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
+      if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
         SmallVector<EVT, 16> vtparts;
 
         ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3d9d2ae372080..b800445a3b19c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
          !isKernelFunction(*F);
 }
 
-bool Isv2x16VT(EVT VT) {
-  return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
-}
-
 } // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 9283b398a9c14..2288241ec0178 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
 
 bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
 
-bool Isv2x16VT(EVT VT);
+inline bool Isv2x16VT(EVT VT) {
+  return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
+}
+
+inline bool shouldPassAsArray(Type *Ty) {
+  return Ty->isAggregateType() || Ty->isVectorTy() ||
+         Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
+}
 
 namespace NVPTX {
 inline std::string getValidPTXIdentifier(StringRef Name) {
diff --git a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
new file mode 100644
index 0000000000000..5b96f4978a7cb
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
+
+target triple = "nvptx64-unknown-cuda"
+
+define fp128 @identity(fp128 %x) {
+; CHECK-LABEL: identity(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
+; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd1, %rd2};
+; CHECK-NEXT:    ret;
+  ret fp128 %x
+}
+
+define void @load_store(ptr %in, ptr %out) {
+; CHECK-LABEL: load_store(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [load_store_param_0];
+; CHECK-NEXT:    ld.u64 %rd2, [%rd1+8];
+; CHECK-NEXT:    ld.u64 %rd3, [%rd1];
+; CHECK-NEXT:    ld.param.u64 %rd4, [load_store_param_1];
+; CHECK-NEXT:    st.u64 [%rd4], %rd3;
+; CHECK-NEXT:    st.u64 [%rd4+8], %rd2;
+; CHECK-NEXT:    ret;
+  %val = load fp128, ptr %in
+  store fp128 %val, ptr %out
+  ret void
+}
+
+define void @call(fp128 %x) {
+; CHECK-LABEL: call(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .align 16 .b8 param0[16];
+; CHECK-NEXT:    st.param.v2.b64 [param0], {%rd1, %rd2};
+; CHECK-NEXT:    call.uni
+; CHECK-NEXT:    call,
+; CHECK-NEXT:    (
+; CHECK-NEXT:    param0
+; CHECK-NEXT:    );
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    ret;
+  call void @call(fp128 %x)
+  ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index e8d7fb3815b79..422f721d934e0 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -4,12 +4,15 @@
 target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
 target triple = "nvptx64-nvidia-cuda"
 
-; Check that we can handle global variables of large integer type.
+; Check that we can handle global variables of large integer and fp128 type.
 
 ; (lsb) 0x0102'0304'0506...0F10 (msb)
 @gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
 ; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
 
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL33333333333333334004033333333333, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 3, 4, 64};
+
 ; Make sure that we do not overflow on large number of elements.
 ; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
 @large_data = global [4831838208 x i8] zeroinitializer

>From ab7d81fef13290a5e0c30a9d46ade99a40c097ee Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 16 Apr 2025 20:33:29 +0000
Subject: [PATCH 2/3] address comments

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     | 35 +++++++++----------
 .../test/CodeGen/NVPTX/global-variable-big.ll |  4 +--
 2 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index e0f9a1ada3bc4..0479278d8ea0d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -259,7 +259,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
     return;
   O << " (";
 
-  auto PrintScalarParam = [&](unsigned Size) {
+  auto PrintScalarRetVal = [&](unsigned Size) {
     O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
   };
   if (shouldPassAsArray(Ty)) {
@@ -269,11 +269,11 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
     O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
       << TotalSize << "]";
   } else if (Ty->isFloatingPointTy()) {
-    PrintScalarParam(Ty->getPrimitiveSizeInBits());
+    PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
   } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
-    PrintScalarParam(ITy->getBitWidth());
+    PrintScalarRetVal(ITy->getBitWidth());
   } else if (isa<PointerType>(Ty)) {
-    PrintScalarParam(TLI->getPointerTy(DL).getSizeInBits());
+    PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
   } else
     llvm_unreachable("Unknown return type");
   O << ") ";
@@ -1674,24 +1674,24 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
                                               AggBuffer *aggBuffer) {
   const DataLayout &DL = getDataLayout();
 
-  auto BufferConstant = [&](APInt Val) {
-    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
+  auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
+    for (unsigned _ : llvm::seq(Val.getBitWidth() / 8)) {
       uint8_t Byte = Val.getLoBits(8).getZExtValue();
-      aggBuffer->addBytes(&Byte, 1, 1);
+      Buffer->addBytes(&Byte, 1, 1);
       Val.lshrInPlace(8);
     }
   };
 
   // Integers of arbitrary width
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
-    BufferConstant(CI->getValue());
+    ExtendBuffer(CI->getValue(), aggBuffer);
     return;
   }
 
   // f128
   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
     if (CFP->getType()->isFP128Ty()) {
-      BufferConstant(CFP->getValueAPF().bitcastToAPInt());
+      ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
       return;
     }
   }
@@ -1705,7 +1705,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
 
   if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
     if (CDS->getNumElements())
-      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
+      for (unsigned i : llvm::seq(CDS->getNumElements()))
         bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
                      aggBuffer);
     return;
@@ -1714,15 +1714,12 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
   if (isa<ConstantStruct>(CPV)) {
     if (CPV->getNumOperands()) {
       StructType *ST = cast<StructType>(CPV->getType());
-      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
-        int Bytes;
-        if (i == (e - 1))
-          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
-                  DL.getTypeAllocSize(ST) -
-                  DL.getStructLayout(ST)->getElementOffset(i);
-        else
-          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
-                  DL.getStructLayout(ST)->getElementOffset(i);
+      for (unsigned i : llvm::seq(CPV->getNumOperands())) {
+        int EndOffset = (i + 1 == CPV->getNumOperands())
+                            ? DL.getStructLayout(ST)->getElementOffset(0) +
+                                  DL.getTypeAllocSize(ST)
+                            : DL.getStructLayout(ST)->getElementOffset(i + 1);
+        int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(i);
         bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
       }
     }
diff --git a/llvm/test/CodeGen/NVPTX/global-variable-big.ll b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
index 422f721d934e0..09f556e72a2bd 100644
--- a/llvm/test/CodeGen/NVPTX/global-variable-big.ll
+++ b/llvm/test/CodeGen/NVPTX/global-variable-big.ll
@@ -10,8 +10,8 @@ target triple = "nvptx64-nvidia-cuda"
 @gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
 ; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
 
-@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL33333333333333334004033333333333, align 16
-; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 3, 4, 64};
+@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL0807060504030201100F0E0D0C0B0A09, align 16
+; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
 
 ; Make sure that we do not overflow on large number of elements.
 ; CHECK: .visible .global .align 1 .b8 large_data[4831838208];

>From a487d6b24065372d7883cdb62d076fbf1100b918 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 16 Apr 2025 23:23:53 +0000
Subject: [PATCH 3/3] address comments

---
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 23 +++++++++--------------
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h   |  6 ++++++
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0479278d8ea0d..2f4b109e8e9e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1675,11 +1675,8 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
   const DataLayout &DL = getDataLayout();
 
   auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
-    for (unsigned _ : llvm::seq(Val.getBitWidth() / 8)) {
-      uint8_t Byte = Val.getLoBits(8).getZExtValue();
-      Buffer->addBytes(&Byte, 1, 1);
-      Val.lshrInPlace(8);
-    }
+    for (unsigned I : llvm::seq(Val.getBitWidth() / 8))
+      Buffer->addByte(Val.extractBitsAsZExtValue(8, I * 8));
   };
 
   // Integers of arbitrary width
@@ -1704,23 +1701,21 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
   }
 
   if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
-    if (CDS->getNumElements())
-      for (unsigned i : llvm::seq(CDS->getNumElements()))
-        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
-                     aggBuffer);
+    for (unsigned I : llvm::seq(CDS->getNumElements()))
+      bufferLEByte(cast<Constant>(CDS->getElementAsConstant(I)), 0, aggBuffer);
     return;
   }
 
   if (isa<ConstantStruct>(CPV)) {
     if (CPV->getNumOperands()) {
       StructType *ST = cast<StructType>(CPV->getType());
-      for (unsigned i : llvm::seq(CPV->getNumOperands())) {
-        int EndOffset = (i + 1 == CPV->getNumOperands())
+      for (unsigned I : llvm::seq(CPV->getNumOperands())) {
+        int EndOffset = (I + 1 == CPV->getNumOperands())
                             ? DL.getStructLayout(ST)->getElementOffset(0) +
                                   DL.getTypeAllocSize(ST)
-                            : DL.getStructLayout(ST)->getElementOffset(i + 1);
-        int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(i);
-        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
+                            : DL.getStructLayout(ST)->getElementOffset(I + 1);
+        int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(I);
+        bufferLEByte(cast<Constant>(CPV->getOperand(I)), Bytes, aggBuffer);
       }
     }
     return;
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index 74daaa2fb7134..ce19f5554bfbe 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -125,6 +125,12 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
       return curpos;
     }
 
+    void addByte(uint8_t Byte) {
+      assert((curpos + 1) <= size);
+      buffer[curpos] = Byte;
+      curpos++;
+    }
+
     unsigned addZeros(int Num) {
       assert((curpos + Num) <= size);
       for (int i = 0; i < Num; ++i) {



More information about the llvm-commits mailing list