[llvm] [NVPTX] Merge consecutive elements while buffering constant vectors with sub-byte datatype. (PR #183628)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 3 16:15:56 PDT 2026
================
@@ -1737,6 +1743,75 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
llvm_unreachable("unsupported constant type in printAggregateConstant()");
}
+void NVPTXAsmPrinter::bufferAggregateConstVec(const ConstantVector *CV,
+ AggBuffer *aggBuffer) {
+ unsigned NumElems = CV->getType()->getNumElements();
+ const unsigned BuffSize = aggBuffer->getBufferSize();
+
+ // Buffer one element at a time if we have allocated enough buffer space.
+ if (BuffSize >= NumElems) {
+ for (const auto &Op : CV->operands())
+ bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
+ return;
+ }
+
+ // Sub-byte datatypes will have more elements than bytes allocated for the
+ // buffer. Merge consecutive elements to form a full byte. We expect that 8 %
+ // sub-byte-elem-size should be 0 and current expected usage is for i4 (for
+ // e2m1-fp4 types).
+ Type *ElemTy = CV->getType()->getElementType();
+ assert(ElemTy->isIntegerTy() && "Expected integer data type.");
+ unsigned ElemTySize = ElemTy->getPrimitiveSizeInBits();
+ assert(ElemTySize < 8 && "Expected sub-byte data type.");
+ assert(8 % ElemTySize == 0 && "Element type size must evenly divide a byte.");
+ // Number of elements to merge to form a full byte.
+ unsigned NumElemsPerByte = 8 / ElemTySize;
+ unsigned NumCompleteBytes = NumElems / NumElemsPerByte;
+ unsigned NumTailElems = NumElems % NumElemsPerByte;
+
+ // Helper lambda to constant-fold sub-vector of sub-byte type elements into
+ // i8. Start and end indices of the sub-vector is provided, along with number
+ // of padding zeros if required.
+ auto ConvertSubCVtoInt8 = [this, &ElemTy](const ConstantVector *CV,
+ unsigned Start, unsigned End,
+ unsigned NumPaddingZeros = 0) {
+ // Collect elements to create sub-vector.
+ SmallVector<Constant *, 8> SubCVElems;
+ for (unsigned I : llvm::seq(Start, End))
+ SubCVElems.push_back(CV->getAggregateElement(I));
+
+ // Optionally pad with zeros.
+ for (auto _ : llvm::seq(NumPaddingZeros))
+ SubCVElems.push_back(ConstantInt::getNullValue(ElemTy));
+
+ auto SubCV = ConstantVector::get(SubCVElems);
+ Type *Int8Ty = IntegerType::get(SubCV->getContext(), 8);
+
+ // Merge elements of the sub-vector using ConstantFolding.
+ ConstantInt *MergedElem =
+ dyn_cast_or_null<ConstantInt>(ConstantFoldConstant(
+ ConstantExpr::getBitCast(const_cast<Constant *>(SubCV), Int8Ty),
+ getDataLayout()));
+
+ if (!MergedElem)
+ report_fatal_error(
+ "Cannot lower vector global with unusual element type");
+
+ return MergedElem;
+ };
+
+ // Iterate through elements of vector one chunk at a time and buffer that
+ // chunk.
+ for (unsigned I : llvm::seq(NumCompleteBytes))
+ bufferLEByte(ConvertSubCVtoInt8(CV, I, I + NumElemsPerByte), 0, aggBuffer);
----------------
Artem-B wrote:
Hmm. This does not look quite right. `I` iterates over bytes, but `ConvertSubCVtoInt8` expects sub-byte _element_ indices for its start and end arguments, so the call should pass `I * NumElemsPerByte, (I + 1) * NumElemsPerByte`.
It may make sense to call the loop variable `ByteIdx` to make the distinction obvious.
https://github.com/llvm/llvm-project/pull/183628
More information about the llvm-commits
mailing list