[llvm] [X86] getConstantFromPool - add basic handling for non-zero address offsets (PR #127225)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 14 09:02:37 PST 2025
https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/127225
As detailed on #127047, getConstantFromPool can't handle cases where the constant pool load address offset is non-zero.
This patch adds an optional pointer argument to store the offset, allowing users that can handle it to correctly extract the sub-constant data at that offset.
This is initially only handled by X86FixupVectorConstantsPass, which uses it to extract the offset constant bits - we don't have thorough test coverage for this yet, so I've only added it for the simpler sext/zext/zmovl cases.
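For reference, the caller-side pattern this introduces in X86FixupVectorConstantsPass looks roughly like the following sketch (simplified from the diff below; the bounds check matches the patch, the rebuild step is elided):

    int64_t ByteOffset = 0;
    if (const Constant *C = X86::getConstantFromPool(MI, OperandNo, &ByteOffset)) {
      unsigned CstBitWidth = C->getType()->getPrimitiveSizeInBits();
      // Only fold when the offsetted sub-constant still lies entirely within
      // the original constant pool entry.
      if (0 <= ByteOffset && (RegBitWidth + (8 * ByteOffset)) <= CstBitWidth) {
        // ... rebuild a narrower constant from the bits starting at ByteOffset
        // and reset the displacement on the rewritten instruction to zero ...
      }
    }

Callers that pass a null ByteOffset pointer keep the old behaviour and still get nullptr for any non-zero displacement.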
>From d91a03560cbfcb351dec8a5c36a8df1bd81c8ede Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 14 Feb 2025 17:00:27 +0000
Subject: [PATCH] [X86] getConstantFromPool - add basic handling for non-zero
address offsets
As detailed on #127047, getConstantFromPool can't handle cases where the constant pool load address offset is non-zero.
This patch adds an optional pointer argument to store the offset, allowing users that can handle it to correctly extract the sub-constant data at that offset.
This is initially only handled by X86FixupVectorConstantsPass, which uses it to extract the offset constant bits - we don't have thorough test coverage for this yet, so I've only added it for the simpler sext/zext/zmovl cases.
---
.../Target/X86/X86FixupVectorConstants.cpp | 65 +++++++++++++------
llvm/lib/Target/X86/X86InstrInfo.cpp | 11 +++-
llvm/lib/Target/X86/X86InstrInfo.h | 5 +-
.../vector-interleaved-load-i32-stride-7.ll | 8 +--
.../vector-interleaved-load-i32-stride-8.ll | 8 +--
5 files changed, 65 insertions(+), 32 deletions(-)
diff --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index 40024baf93fdb..457255884f60e 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -139,8 +139,16 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
}
static std::optional<APInt> extractConstantBits(const Constant *C,
- unsigned NumBits) {
+ int64_t ByteOffset) {
+ int64_t BitOffset = ByteOffset * 8;
if (std::optional<APInt> Bits = extractConstantBits(C))
+ return Bits->extractBits(Bits->getBitWidth() - BitOffset, BitOffset);
+ return std::nullopt;
+}
+
+static std::optional<APInt>
+extractConstantBits(const Constant *C, int64_t ByteOffset, unsigned NumBits) {
+ if (std::optional<APInt> Bits = extractConstantBits(C, ByteOffset))
return Bits->zextOrTrunc(NumBits);
return std::nullopt;
}
@@ -148,11 +156,16 @@ static std::optional<APInt> extractConstantBits(const Constant *C,
// Attempt to compute the splat width of bits data by normalizing the splat to
// remove undefs.
static std::optional<APInt> getSplatableConstant(const Constant *C,
+ int64_t ByteOffset,
unsigned SplatBitWidth) {
const Type *Ty = C->getType();
assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
"Illegal splat width");
+ // TODO: Add ByteOffset support once we have test coverage.
+ if (ByteOffset != 0)
+ return std::nullopt;
+
if (std::optional<APInt> Bits = extractConstantBits(C))
if (Bits->isSplat(SplatBitWidth))
return Bits->trunc(SplatBitWidth);
@@ -241,10 +254,12 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
-static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
- unsigned /*NumElts*/, unsigned SplatBitWidth) {
+static Constant *rebuildSplatCst(const Constant *C, int64_t ByteOffset,
+ unsigned /*NumBits*/, unsigned /*NumElts*/,
+ unsigned SplatBitWidth) {
// TODO: Truncate to NumBits once ConvertToBroadcastAVX512 support this.
- std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
+ std::optional<APInt> Splat =
+ getSplatableConstant(C, ByteOffset, SplatBitWidth);
if (!Splat)
return nullptr;
@@ -263,8 +278,8 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
}
-static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
- unsigned /*NumElts*/,
+static Constant *rebuildZeroUpperCst(const Constant *C, int64_t ByteOffset,
+ unsigned NumBits, unsigned /*NumElts*/,
unsigned ScalarBitWidth) {
Type *SclTy = C->getType()->getScalarType();
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
@@ -272,7 +287,8 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
if (NumBits > ScalarBitWidth) {
// Determine if the upper bits are all zero.
- if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
+ if (std::optional<APInt> Bits =
+ extractConstantBits(C, ByteOffset, NumBits)) {
if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
// If the original constant was made of smaller elements, try to retain
// those types.
@@ -290,14 +306,14 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
}
static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
- unsigned NumBits, unsigned NumElts,
- unsigned SrcEltBitWidth) {
+ int64_t ByteOffset, unsigned NumBits,
+ unsigned NumElts, unsigned SrcEltBitWidth) {
unsigned DstEltBitWidth = NumBits / NumElts;
assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
(DstEltBitWidth % SrcEltBitWidth) == 0 &&
(DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
- if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
+ if (std::optional<APInt> Bits = extractConstantBits(C, ByteOffset, NumBits)) {
assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
(Bits->getBitWidth() % DstEltBitWidth) == 0 &&
"Unexpected constant extension");
@@ -319,13 +335,15 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
return nullptr;
}
-static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
- unsigned NumElts, unsigned SrcEltBitWidth) {
- return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
+static Constant *rebuildSExtCst(const Constant *C, int64_t ByteOffset,
+ unsigned NumBits, unsigned NumElts,
+ unsigned SrcEltBitWidth) {
+ return rebuildExtCst(C, true, ByteOffset, NumBits, NumElts, SrcEltBitWidth);
}
-static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
- unsigned NumElts, unsigned SrcEltBitWidth) {
- return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
+static Constant *rebuildZExtCst(const Constant *C, int64_t ByteOffset,
+ unsigned NumBits, unsigned NumElts,
+ unsigned SrcEltBitWidth) {
+ return rebuildExtCst(C, false, ByteOffset, NumBits, NumElts, SrcEltBitWidth);
}
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -344,7 +362,8 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
int Op;
int NumCstElts;
int MemBitWidth;
- std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
+ std::function<Constant *(const Constant *, int64_t, unsigned, unsigned,
+ unsigned)>
RebuildConstant;
};
auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned RegBitWidth,
@@ -359,19 +378,23 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
#endif
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
"Unexpected number of operands!");
- if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
+ int64_t ByteOffset = 0;
+ if (auto *C = X86::getConstantFromPool(MI, OperandNo, &ByteOffset)) {
unsigned CstBitWidth = C->getType()->getPrimitiveSizeInBits();
RegBitWidth = RegBitWidth ? RegBitWidth : CstBitWidth;
for (const FixupEntry &Fixup : Fixups) {
- if (Fixup.Op) {
+ if (Fixup.Op && 0 <= ByteOffset &&
+ (RegBitWidth + (8 * ByteOffset)) <= CstBitWidth) {
// Construct a suitable constant and adjust the MI to use the new
// constant pool entry.
- if (Constant *NewCst = Fixup.RebuildConstant(
- C, RegBitWidth, Fixup.NumCstElts, Fixup.MemBitWidth)) {
+ if (Constant *NewCst =
+ Fixup.RebuildConstant(C, ByteOffset, RegBitWidth,
+ Fixup.NumCstElts, Fixup.MemBitWidth)) {
unsigned NewCPI =
CP->getConstantPoolIndex(NewCst, Align(Fixup.MemBitWidth / 8));
MI.setDesc(TII->get(Fixup.Op));
MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
+ MI.getOperand(OperandNo + X86::AddrDisp).setOffset(0);
return true;
}
}
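To make the offset arithmetic above concrete, here is a small standalone sketch of what the new extractConstantBits overload computes (the demoOffsetExtract helper and the 512-bit / +16-byte values are illustrative assumptions taken from the +16(%rip) cases in the tests below):

    #include "llvm/ADT/APInt.h"
    using namespace llvm;

    // FullBits holds the raw bits of the whole constant pool entry (e.g. 512
    // bits); ByteOffset is the load displacement (e.g. 16 bytes).
    APInt demoOffsetExtract(const APInt &FullBits, unsigned ByteOffset) {
      unsigned BitOffset = ByteOffset * 8; // 16 bytes -> 128 bits
      // Keep the upper (512 - 128) = 384 bits starting at bit 128, i.e. the
      // sub-constant that the offsetted load actually reads.
      return FullBits.extractBits(FullBits.getBitWidth() - BitOffset, BitOffset);
    }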
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 44db5b6865c42..968f7ecd1b12b 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3656,7 +3656,7 @@ int X86::getFirstAddrOperandIdx(const MachineInstr &MI) {
}
const Constant *X86::getConstantFromPool(const MachineInstr &MI,
- unsigned OpNo) {
+ unsigned OpNo, int64_t *ByteOffset) {
assert(MI.getNumOperands() >= (OpNo + X86::AddrNumOperands) &&
"Unexpected number of operands!");
@@ -3665,7 +3665,11 @@ const Constant *X86::getConstantFromPool(const MachineInstr &MI,
return nullptr;
const MachineOperand &Disp = MI.getOperand(OpNo + X86::AddrDisp);
- if (!Disp.isCPI() || Disp.getOffset() != 0)
+ if (!Disp.isCPI())
+ return nullptr;
+
+ int64_t Offset = Disp.getOffset();
+ if (Offset != 0 && !ByteOffset)
return nullptr;
ArrayRef<MachineConstantPoolEntry> Constants =
@@ -3677,6 +3681,9 @@ const Constant *X86::getConstantFromPool(const MachineInstr &MI,
if (ConstantEntry.isMachineConstantPoolEntry())
return nullptr;
+ if (ByteOffset)
+ *ByteOffset = Offset;
+
return ConstantEntry.Val.ConstVal;
}
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index 5f87e02fe67c4..35f3d7f1d319f 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -112,7 +112,10 @@ bool isX87Instruction(MachineInstr &MI);
int getFirstAddrOperandIdx(const MachineInstr &MI);
/// Find any constant pool entry associated with a specific instruction operand.
-const Constant *getConstantFromPool(const MachineInstr &MI, unsigned OpNo);
+/// By default returns null if the address offset is non-zero, but will return
+/// the entry if \p ByteOffset is non-null to store the value.
+const Constant *getConstantFromPool(const MachineInstr &MI, unsigned OpNo,
+ int64_t *ByteOffset = nullptr);
} // namespace X86
diff --git a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll
index 955a7ffcec795..90d4fb3dd4aa9 100644
--- a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll
+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll
@@ -242,7 +242,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
; AVX512-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
; AVX512-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
; AVX512-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
; AVX512-FCP-NEXT: vmovq %xmm3, (%rsi)
; AVX512-FCP-NEXT: vmovq %xmm4, (%rdx)
@@ -307,7 +307,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
; AVX512DQ-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
; AVX512DQ-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512DQ-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
; AVX512DQ-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
; AVX512DQ-FCP-NEXT: vmovq %xmm3, (%rsi)
; AVX512DQ-FCP-NEXT: vmovq %xmm4, (%rdx)
@@ -372,7 +372,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
; AVX512BW-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
; AVX512BW-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
; AVX512BW-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
; AVX512BW-FCP-NEXT: vmovq %xmm3, (%rsi)
; AVX512BW-FCP-NEXT: vmovq %xmm4, (%rdx)
@@ -437,7 +437,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
; AVX512DQ-BW-FCP-NEXT: vmovdqa 32(%rdi), %ymm7
; AVX512DQ-BW-FCP-NEXT: vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512DQ-BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
; AVX512DQ-BW-FCP-NEXT: vpermps %zmm0, %zmm2, %zmm0
; AVX512DQ-BW-FCP-NEXT: vmovq %xmm3, (%rsi)
; AVX512DQ-BW-FCP-NEXT: vmovq %xmm4, (%rdx)
diff --git a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll
index 13410fb5cc4b8..964b8feda1901 100644
--- a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll
+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll
@@ -226,7 +226,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512-FCP-NEXT: vmovaps (%rdi), %ymm4
; AVX512-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
; AVX512-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
-; AVX512-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
; AVX512-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
; AVX512-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
; AVX512-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4
@@ -291,7 +291,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512DQ-FCP-NEXT: vmovaps (%rdi), %ymm4
; AVX512DQ-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
; AVX512DQ-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
-; AVX512DQ-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512DQ-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
; AVX512DQ-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
; AVX512DQ-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
; AVX512DQ-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4
@@ -356,7 +356,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512BW-FCP-NEXT: vmovaps (%rdi), %ymm4
; AVX512BW-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
; AVX512BW-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
-; AVX512BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
; AVX512BW-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
; AVX512BW-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
; AVX512BW-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4
@@ -421,7 +421,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
; AVX512DQ-BW-FCP-NEXT: vmovaps (%rdi), %ymm4
; AVX512DQ-BW-FCP-NEXT: vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
; AVX512DQ-BW-FCP-NEXT: vextractf128 $1, %ymm5, %xmm5
-; AVX512DQ-BW-FCP-NEXT: vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512DQ-BW-FCP-NEXT: vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
; AVX512DQ-BW-FCP-NEXT: vpermps (%rdi), %zmm6, %zmm6
; AVX512DQ-BW-FCP-NEXT: vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
; AVX512DQ-BW-FCP-NEXT: vextractf128 $1, %ymm1, %xmm4