[llvm] b846613 - [X86] X86FixupVectorConstants - add destination register width to rebuildSplatCst/rebuildZeroUpperCst/rebuildExtCst callbacks
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 8 08:35:26 PST 2024
Author: Simon Pilgrim
Date: 2024-02-08T16:35:13Z
New Revision: b846613837d83989d99d33f4b90db7bad019aa8c
URL: https://github.com/llvm/llvm-project/commit/b846613837d83989d99d33f4b90db7bad019aa8c
DIFF: https://github.com/llvm/llvm-project/commit/b846613837d83989d99d33f4b90db7bad019aa8c.diff
LOG: [X86] X86FixupVectorConstants - add destination register width to rebuildSplatCst/rebuildZeroUpperCst/rebuildExtCst callbacks
As found in #81136, we aren't correctly handling cases where the constant pool entry is wider than the destination register width, causing incorrect scaling of the truncated constant for load-extension cases.
This first patch just pulls out the destination register width argument; it's still currently driven by the constant pool entry, but that will be addressed in a followup.
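To make the failure mode concrete, here is a small standalone sketch (the widths and element counts are illustrative assumptions, not taken from #81136): when the rebuild helpers derive NumBits from the constant pool entry rather than the destination register, the per-element width used to scale the truncated constant comes out too large.

#include <cassert>

int main() {
  // Illustrative assumption: a 256-bit constant pool entry feeding a
  // load-extension whose destination register is only 128 bits wide,
  // producing 4 elements extended from 8-bit source elements.
  unsigned CstPoolBits = 256;   // width of the constant pool entry
  unsigned DstRegBits = 128;    // width of the destination register
  unsigned NumElts = 4;
  unsigned SrcEltBitWidth = 8;

  // Deriving the width from the constant's own type over-scales the
  // destination element width...
  unsigned WideDstEltBitWidth = CstPoolBits / NumElts; // 64 - too wide
  // ...whereas scaling by the destination register width is what the
  // followup patch intends once the argument is driven by the register.
  unsigned DstEltBitWidth = DstRegBits / NumElts;      // 32

  assert(WideDstEltBitWidth == 64 && DstEltBitWidth == 32);
  return 0;
}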
Added:
Modified:
llvm/lib/Target/X86/X86FixupVectorConstants.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index 9c46cee572fc9..9b90b5e4bc1ea 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -121,6 +121,13 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
return std::nullopt;
}
+static std::optional<APInt> extractConstantBits(const Constant *C,
+ unsigned NumBits) {
+ if (std::optional<APInt> Bits = extractConstantBits(C))
+ return Bits->zextOrTrunc(NumBits);
+ return std::nullopt;
+}
+
// Attempt to compute the splat width of bits data by normalizing the splat to
// remove undefs.
static std::optional<APInt> getSplatableConstant(const Constant *C,
@@ -217,16 +224,15 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
-static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
- unsigned SplatBitWidth) {
+static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
+ unsigned /*NumElts*/, unsigned SplatBitWidth) {
std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
if (!Splat)
return nullptr;
// Determine scalar size to use for the constant splat vector, clamping as we
// might have found a splat smaller than the original constant data.
- const Type *OriginalType = C->getType();
- Type *SclTy = OriginalType->getScalarType();
+ Type *SclTy = C->getType()->getScalarType();
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
@@ -236,20 +242,19 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
: 64;
// Extract per-element bits.
- return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
+ return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
}
-static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
+static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
+ unsigned /*NumElts*/,
unsigned ScalarBitWidth) {
- Type *Ty = C->getType();
- Type *SclTy = Ty->getScalarType();
- unsigned NumBits = Ty->getPrimitiveSizeInBits();
+ Type *SclTy = C->getType()->getScalarType();
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
LLVMContext &Ctx = C->getContext();
if (NumBits > ScalarBitWidth) {
// Determine if the upper bits are all zero.
- if (std::optional<APInt> Bits = extractConstantBits(C)) {
+ if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
// If the original constant was made of smaller elements, try to retain
// those types.
@@ -266,16 +271,15 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
return nullptr;
}
-static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
+static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
+ unsigned NumBits, unsigned NumElts,
unsigned SrcEltBitWidth) {
- Type *Ty = C->getType();
- unsigned NumBits = Ty->getPrimitiveSizeInBits();
unsigned DstEltBitWidth = NumBits / NumElts;
assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
(DstEltBitWidth % SrcEltBitWidth) == 0 &&
(DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
- if (std::optional<APInt> Bits = extractConstantBits(C)) {
+ if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
(Bits->getBitWidth() % DstEltBitWidth) == 0 &&
"Unexpected constant extension");
@@ -290,19 +294,20 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
}
+ Type *Ty = C->getType();
return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
SrcEltBitWidth);
}
return nullptr;
}
-static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
- unsigned SrcEltBitWidth) {
- return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
+static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
+ unsigned NumElts, unsigned SrcEltBitWidth) {
+ return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
}
-static Constant *rebuildZExtCst(const Constant *C, unsigned NumElts,
- unsigned SrcEltBitWidth) {
- return rebuildExtCst(C, false, NumElts, SrcEltBitWidth);
+static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
+ unsigned NumElts, unsigned SrcEltBitWidth) {
+ return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
}
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -320,7 +325,7 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
int Op;
int NumCstElts;
int BitWidth;
- std::function<Constant *(const Constant *, unsigned, unsigned)>
+ std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
RebuildConstant;
};
auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
@@ -335,12 +340,13 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
"Unexpected number of operands!");
if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
+ unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
for (const FixupEntry &Fixup : Fixups) {
if (Fixup.Op) {
// Construct a suitable constant and adjust the MI to use the new
// constant pool entry.
- if (Constant *NewCst =
- Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
+ if (Constant *NewCst = Fixup.RebuildConstant(
+ C, NumBits, Fixup.NumCstElts, Fixup.BitWidth)) {
unsigned NewCPI =
CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
MI.setDesc(TII->get(Fixup.Op));
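As a minimal sketch of the interface change only (using simplified stand-in types, not the LLVM classes), the FixupEntry callbacks now take the destination width as an explicit NumBits argument instead of re-deriving it from the constant inside each rebuild helper; in this patch NumBits is still computed from the constant pool entry at the call site.

#include <cstdio>
#include <functional>

struct Constant {};  // stand-in for llvm::Constant

// Old shape: (C, NumElts, BitWidth) - the width came from C->getType().
// New shape: (C, NumBits, NumElts, BitWidth).
using RebuildFn =
    std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>;

struct FixupEntry {
  int Op;
  int NumCstElts;
  int BitWidth;
  RebuildFn RebuildConstant;
};

int main() {
  FixupEntry Fixup{/*Op=*/1, /*NumCstElts=*/4, /*BitWidth=*/8,
                   [](const Constant *, unsigned NumBits, unsigned NumElts,
                      unsigned SrcEltBitWidth) -> Constant * {
                     std::printf("NumBits=%u NumElts=%u SrcEltBits=%u\n",
                                 NumBits, NumElts, SrcEltBitWidth);
                     return nullptr;
                   }};
  Constant C;
  // Still driven by the constant pool entry width in this patch; a
  // followup switches this to the destination register width.
  unsigned NumBits = 128;
  Fixup.RebuildConstant(&C, NumBits, Fixup.NumCstElts, Fixup.BitWidth);
  return 0;
}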