[llvm] [X86] getConstantFromPool - add basic handling for non-zero address offsets (PR #127225)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 14 09:02:37 PST 2025


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/127225

As detailed in #127047, getConstantFromPool can't handle cases where the constant pool load address offset is non-zero.

This patch adds an optional pointer argument to store the offset, allowing users that can handle it to correctly extract the offset sub-constant data.

This is initially only handled by X86FixupVectorConstantsPass, which uses it to extract the offset constant bits - we don't have thorough test coverage for this yet, so I've only added it for the simpler sext/zext/zmovl cases.

>From d91a03560cbfcb351dec8a5c36a8df1bd81c8ede Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 14 Feb 2025 17:00:27 +0000
Subject: [PATCH] [X86] getConstantFromPool - add basic handling for non-zero
 address offsets

As detailed in #127047, getConstantFromPool can't handle cases where the constant pool load address offset is non-zero.

This patch adds an optional pointer argument to store the offset, allowing users that can handle it to correctly extract the offset sub-constant data.

This is initially only handled by X86FixupVectorConstantsPass, which uses it to extract the offset constant bits - we don't have thorough test coverage for this yet, so I've only added it for the simpler sext/zext/zmovl cases.
---
 .../Target/X86/X86FixupVectorConstants.cpp    | 65 +++++++++++++------
 llvm/lib/Target/X86/X86InstrInfo.cpp          | 11 +++-
 llvm/lib/Target/X86/X86InstrInfo.h            |  5 +-
 .../vector-interleaved-load-i32-stride-7.ll   |  8 +--
 .../vector-interleaved-load-i32-stride-8.ll   |  8 +--
 5 files changed, 65 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index 40024baf93fdb..457255884f60e 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -139,8 +139,16 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
 }
 
 static std::optional<APInt> extractConstantBits(const Constant *C,
-                                                unsigned NumBits) {
+                                                int64_t ByteOffset) {
+  int64_t BitOffset = ByteOffset * 8;
   if (std::optional<APInt> Bits = extractConstantBits(C))
+    return Bits->extractBits(Bits->getBitWidth() - BitOffset, BitOffset);
+  return std::nullopt;
+}
+
+static std::optional<APInt>
+extractConstantBits(const Constant *C, int64_t ByteOffset, unsigned NumBits) {
+  if (std::optional<APInt> Bits = extractConstantBits(C, ByteOffset))
     return Bits->zextOrTrunc(NumBits);
   return std::nullopt;
 }
@@ -148,11 +156,16 @@ static std::optional<APInt> extractConstantBits(const Constant *C,
 // Attempt to compute the splat width of bits data by normalizing the splat to
 // remove undefs.
 static std::optional<APInt> getSplatableConstant(const Constant *C,
+                                                 int64_t ByteOffset,
                                                  unsigned SplatBitWidth) {
   const Type *Ty = C->getType();
   assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
          "Illegal splat width");
 
+  // TODO: Add ByteOffset support once we have test coverage.
+  if (ByteOffset != 0)
+    return std::nullopt;
+
   if (std::optional<APInt> Bits = extractConstantBits(C))
     if (Bits->isSplat(SplatBitWidth))
       return Bits->trunc(SplatBitWidth);
@@ -241,10 +254,12 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
 
 // Attempt to rebuild a normalized splat vector constant of the requested splat
 // width, built up of potentially smaller scalar values.
-static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
-                                 unsigned /*NumElts*/, unsigned SplatBitWidth) {
+static Constant *rebuildSplatCst(const Constant *C, int64_t ByteOffset,
+                                 unsigned /*NumBits*/, unsigned /*NumElts*/,
+                                 unsigned SplatBitWidth) {
   // TODO: Truncate to NumBits once ConvertToBroadcastAVX512 support this.
-  std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
+  std::optional<APInt> Splat =
+      getSplatableConstant(C, ByteOffset, SplatBitWidth);
   if (!Splat)
     return nullptr;
 
@@ -263,8 +278,8 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
   return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
 }
 
-static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
-                                     unsigned /*NumElts*/,
+static Constant *rebuildZeroUpperCst(const Constant *C, int64_t ByteOffset,
+                                     unsigned NumBits, unsigned /*NumElts*/,
                                      unsigned ScalarBitWidth) {
   Type *SclTy = C->getType()->getScalarType();
   unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
@@ -272,7 +287,8 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
 
   if (NumBits > ScalarBitWidth) {
     // Determine if the upper bits are all zero.
-    if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
+    if (std::optional<APInt> Bits =
+            extractConstantBits(C, ByteOffset, NumBits)) {
       if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
         // If the original constant was made of smaller elements, try to retain
         // those types.
@@ -290,14 +306,14 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
 }
 
 static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
-                               unsigned NumBits, unsigned NumElts,
-                               unsigned SrcEltBitWidth) {
+                               int64_t ByteOffset, unsigned NumBits,
+                               unsigned NumElts, unsigned SrcEltBitWidth) {
   unsigned DstEltBitWidth = NumBits / NumElts;
   assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
          (DstEltBitWidth % SrcEltBitWidth) == 0 &&
          (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
 
-  if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
+  if (std::optional<APInt> Bits = extractConstantBits(C, ByteOffset, NumBits)) {
     assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
            (Bits->getBitWidth() % DstEltBitWidth) == 0 &&
            "Unexpected constant extension");
@@ -319,13 +335,15 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
 
   return nullptr;
 }
-static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
-                                unsigned NumElts, unsigned SrcEltBitWidth) {
-  return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
+static Constant *rebuildSExtCst(const Constant *C, int64_t ByteOffset,
+                                unsigned NumBits, unsigned NumElts,
+                                unsigned SrcEltBitWidth) {
+  return rebuildExtCst(C, true, ByteOffset, NumBits, NumElts, SrcEltBitWidth);
 }
-static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
-                                unsigned NumElts, unsigned SrcEltBitWidth) {
-  return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
+static Constant *rebuildZExtCst(const Constant *C, int64_t ByteOffset,
+                                unsigned NumBits, unsigned NumElts,
+                                unsigned SrcEltBitWidth) {
+  return rebuildExtCst(C, false, ByteOffset, NumBits, NumElts, SrcEltBitWidth);
 }
 
 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -344,7 +362,8 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
     int Op;
     int NumCstElts;
     int MemBitWidth;
-    std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
+    std::function<Constant *(const Constant *, int64_t, unsigned, unsigned,
+                             unsigned)>
         RebuildConstant;
   };
   auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned RegBitWidth,
@@ -359,19 +378,23 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
 #endif
     assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
            "Unexpected number of operands!");
-    if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
+    int64_t ByteOffset = 0;
+    if (auto *C = X86::getConstantFromPool(MI, OperandNo, &ByteOffset)) {
       unsigned CstBitWidth = C->getType()->getPrimitiveSizeInBits();
       RegBitWidth = RegBitWidth ? RegBitWidth : CstBitWidth;
       for (const FixupEntry &Fixup : Fixups) {
-        if (Fixup.Op) {
+        if (Fixup.Op && 0 <= ByteOffset &&
+            (RegBitWidth + (8 * ByteOffset)) <= CstBitWidth) {
           // Construct a suitable constant and adjust the MI to use the new
           // constant pool entry.
-          if (Constant *NewCst = Fixup.RebuildConstant(
-                  C, RegBitWidth, Fixup.NumCstElts, Fixup.MemBitWidth)) {
+          if (Constant *NewCst =
+                  Fixup.RebuildConstant(C, ByteOffset, RegBitWidth,
+                                        Fixup.NumCstElts, Fixup.MemBitWidth)) {
             unsigned NewCPI =
                 CP->getConstantPoolIndex(NewCst, Align(Fixup.MemBitWidth / 8));
             MI.setDesc(TII->get(Fixup.Op));
             MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
+            MI.getOperand(OperandNo + X86::AddrDisp).setOffset(0);
             return true;
           }
         }
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 44db5b6865c42..968f7ecd1b12b 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3656,7 +3656,7 @@ int X86::getFirstAddrOperandIdx(const MachineInstr &MI) {
 }
 
 const Constant *X86::getConstantFromPool(const MachineInstr &MI,
-                                         unsigned OpNo) {
+                                         unsigned OpNo, int64_t *ByteOffset) {
   assert(MI.getNumOperands() >= (OpNo + X86::AddrNumOperands) &&
          "Unexpected number of operands!");
 
@@ -3665,7 +3665,11 @@ const Constant *X86::getConstantFromPool(const MachineInstr &MI,
     return nullptr;
 
   const MachineOperand &Disp = MI.getOperand(OpNo + X86::AddrDisp);
-  if (!Disp.isCPI() || Disp.getOffset() != 0)
+  if (!Disp.isCPI())
+    return nullptr;
+
+  int64_t Offset = Disp.getOffset();
+  if (Offset != 0 && !ByteOffset)
     return nullptr;
 
   ArrayRef<MachineConstantPoolEntry> Constants =
@@ -3677,6 +3681,9 @@ const Constant *X86::getConstantFromPool(const MachineInstr &MI,
   if (ConstantEntry.isMachineConstantPoolEntry())
     return nullptr;
 
+  if (ByteOffset)
+    *ByteOffset = Offset;
+
   return ConstantEntry.Val.ConstVal;
 }
 
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index 5f87e02fe67c4..35f3d7f1d319f 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -112,7 +112,10 @@ bool isX87Instruction(MachineInstr &MI);
 int getFirstAddrOperandIdx(const MachineInstr &MI);
 
 /// Find any constant pool entry associated with a specific instruction operand.
-const Constant *getConstantFromPool(const MachineInstr &MI, unsigned OpNo);
+/// By default returns null if the address offset is non-zero, but will return
+/// the entry if \p ByteOffset is non-null to store the value.
+const Constant *getConstantFromPool(const MachineInstr &MI, unsigned OpNo,
+                                    int64_t *ByteOffset = nullptr);
 
 } // namespace X86
 
diff --git a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll
index 955a7ffcec795..90d4fb3dd4aa9 100644
--- a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll
+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-7.ll
@@ -242,7 +242,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
 ; AVX512-FCP-NEXT:    vmovdqa 32(%rdi), %ymm7
 ; AVX512-FCP-NEXT:    vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
 ; AVX512-FCP-NEXT:    vpermps %zmm0, %zmm2, %zmm0
 ; AVX512-FCP-NEXT:    vmovq %xmm3, (%rsi)
 ; AVX512-FCP-NEXT:    vmovq %xmm4, (%rdx)
@@ -307,7 +307,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512DQ-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
 ; AVX512DQ-FCP-NEXT:    vmovdqa 32(%rdi), %ymm7
 ; AVX512DQ-FCP-NEXT:    vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512DQ-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512DQ-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
 ; AVX512DQ-FCP-NEXT:    vpermps %zmm0, %zmm2, %zmm0
 ; AVX512DQ-FCP-NEXT:    vmovq %xmm3, (%rsi)
 ; AVX512DQ-FCP-NEXT:    vmovq %xmm4, (%rdx)
@@ -372,7 +372,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512BW-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
 ; AVX512BW-FCP-NEXT:    vmovdqa 32(%rdi), %ymm7
 ; AVX512BW-FCP-NEXT:    vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512BW-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512BW-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
 ; AVX512BW-FCP-NEXT:    vpermps %zmm0, %zmm2, %zmm0
 ; AVX512BW-FCP-NEXT:    vmovq %xmm3, (%rsi)
 ; AVX512BW-FCP-NEXT:    vmovq %xmm4, (%rdx)
@@ -437,7 +437,7 @@ define void @load_i32_stride7_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512DQ-BW-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [13,4,6,7]
 ; AVX512DQ-BW-FCP-NEXT:    vmovdqa 32(%rdi), %ymm7
 ; AVX512DQ-BW-FCP-NEXT:    vpermt2d (%rdi), %ymm2, %ymm7
-; AVX512DQ-BW-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm2
+; AVX512DQ-BW-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm2 = [6,13,6,7]
 ; AVX512DQ-BW-FCP-NEXT:    vpermps %zmm0, %zmm2, %zmm0
 ; AVX512DQ-BW-FCP-NEXT:    vmovq %xmm3, (%rsi)
 ; AVX512DQ-BW-FCP-NEXT:    vmovq %xmm4, (%rdx)
diff --git a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll
index 13410fb5cc4b8..964b8feda1901 100644
--- a/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll
+++ b/llvm/test/CodeGen/X86/vector-interleaved-load-i32-stride-8.ll
@@ -226,7 +226,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512-FCP-NEXT:    vmovaps (%rdi), %ymm4
 ; AVX512-FCP-NEXT:    vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
 ; AVX512-FCP-NEXT:    vextractf128 $1, %ymm5, %xmm5
-; AVX512-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
 ; AVX512-FCP-NEXT:    vpermps (%rdi), %zmm6, %zmm6
 ; AVX512-FCP-NEXT:    vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
 ; AVX512-FCP-NEXT:    vextractf128 $1, %ymm1, %xmm4
@@ -291,7 +291,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512DQ-FCP-NEXT:    vmovaps (%rdi), %ymm4
 ; AVX512DQ-FCP-NEXT:    vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
 ; AVX512DQ-FCP-NEXT:    vextractf128 $1, %ymm5, %xmm5
-; AVX512DQ-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512DQ-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
 ; AVX512DQ-FCP-NEXT:    vpermps (%rdi), %zmm6, %zmm6
 ; AVX512DQ-FCP-NEXT:    vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
 ; AVX512DQ-FCP-NEXT:    vextractf128 $1, %ymm1, %xmm4
@@ -356,7 +356,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512BW-FCP-NEXT:    vmovaps (%rdi), %ymm4
 ; AVX512BW-FCP-NEXT:    vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
 ; AVX512BW-FCP-NEXT:    vextractf128 $1, %ymm5, %xmm5
-; AVX512BW-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512BW-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
 ; AVX512BW-FCP-NEXT:    vpermps (%rdi), %zmm6, %zmm6
 ; AVX512BW-FCP-NEXT:    vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
 ; AVX512BW-FCP-NEXT:    vextractf128 $1, %ymm1, %xmm4
@@ -421,7 +421,7 @@ define void @load_i32_stride8_vf2(ptr %in.vec, ptr %out.vec0, ptr %out.vec1, ptr
 ; AVX512DQ-BW-FCP-NEXT:    vmovaps (%rdi), %ymm4
 ; AVX512DQ-BW-FCP-NEXT:    vunpcklps {{.*#+}} ymm5 = ymm4[0],ymm1[0],ymm4[1],ymm1[1],ymm4[4],ymm1[4],ymm4[5],ymm1[5]
 ; AVX512DQ-BW-FCP-NEXT:    vextractf128 $1, %ymm5, %xmm5
-; AVX512DQ-BW-FCP-NEXT:    vmovaps {{\.?LCPI[0-9]+_[0-9]+}}+16(%rip), %xmm6
+; AVX512DQ-BW-FCP-NEXT:    vpmovsxbd {{.*#+}} xmm6 = [5,13,5,5]
 ; AVX512DQ-BW-FCP-NEXT:    vpermps (%rdi), %zmm6, %zmm6
 ; AVX512DQ-BW-FCP-NEXT:    vunpckhps {{.*#+}} ymm1 = ymm4[2],ymm1[2],ymm4[3],ymm1[3],ymm4[6],ymm1[6],ymm4[7],ymm1[7]
 ; AVX512DQ-BW-FCP-NEXT:    vextractf128 $1, %ymm1, %xmm4



More information about the llvm-commits mailing list