[llvm] b846613 - [X86] X86FixupVectorConstants - add destination register width to rebuildSplatCst/rebuildZeroUpperCst/rebuildExtCst callbacks

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 08:35:26 PST 2024


Author: Simon Pilgrim
Date: 2024-02-08T16:35:13Z
New Revision: b846613837d83989d99d33f4b90db7bad019aa8c

URL: https://github.com/llvm/llvm-project/commit/b846613837d83989d99d33f4b90db7bad019aa8c
DIFF: https://github.com/llvm/llvm-project/commit/b846613837d83989d99d33f4b90db7bad019aa8c.diff

LOG: [X86] X86FixupVectorConstants - add destination register width to rebuildSplatCst/rebuildZeroUpperCst/rebuildExtCst callbacks

As found in #81136, we aren't correctly handling cases where the constant pool entry is wider than the destination register, which causes incorrect scaling of the truncated constant for load-extension cases.

This first patch just pulls out the destination register width argument; it's still currently driven by the constant pool entry, but that will be addressed in a follow-up.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86FixupVectorConstants.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index 9c46cee572fc9..9b90b5e4bc1ea 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -121,6 +121,13 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
   return std::nullopt;
 }
 
+static std::optional<APInt> extractConstantBits(const Constant *C,
+                                                unsigned NumBits) {
+  if (std::optional<APInt> Bits = extractConstantBits(C))
+    return Bits->zextOrTrunc(NumBits);
+  return std::nullopt;
+}
+
 // Attempt to compute the splat width of bits data by normalizing the splat to
 // remove undefs.
 static std::optional<APInt> getSplatableConstant(const Constant *C,
@@ -217,16 +224,15 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
 
 // Attempt to rebuild a normalized splat vector constant of the requested splat
 // width, built up of potentially smaller scalar values.
-static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
-                                 unsigned SplatBitWidth) {
+static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
+                                 unsigned /*NumElts*/, unsigned SplatBitWidth) {
   std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
   if (!Splat)
     return nullptr;
 
   // Determine scalar size to use for the constant splat vector, clamping as we
   // might have found a splat smaller than the original constant data.
-  const Type *OriginalType = C->getType();
-  Type *SclTy = OriginalType->getScalarType();
+  Type *SclTy = C->getType()->getScalarType();
   unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
   NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
 
@@ -236,20 +242,19 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
                    : 64;
 
   // Extract per-element bits.
-  return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
+  return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
 }
 
-static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
+static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
+                                     unsigned /*NumElts*/,
                                      unsigned ScalarBitWidth) {
-  Type *Ty = C->getType();
-  Type *SclTy = Ty->getScalarType();
-  unsigned NumBits = Ty->getPrimitiveSizeInBits();
+  Type *SclTy = C->getType()->getScalarType();
   unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
   LLVMContext &Ctx = C->getContext();
 
   if (NumBits > ScalarBitWidth) {
     // Determine if the upper bits are all zero.
-    if (std::optional<APInt> Bits = extractConstantBits(C)) {
+    if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
       if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
         // If the original constant was made of smaller elements, try to retain
         // those types.
@@ -266,16 +271,15 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
   return nullptr;
 }
 
-static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
+static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
+                               unsigned NumBits, unsigned NumElts,
                                unsigned SrcEltBitWidth) {
-  Type *Ty = C->getType();
-  unsigned NumBits = Ty->getPrimitiveSizeInBits();
   unsigned DstEltBitWidth = NumBits / NumElts;
   assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
          (DstEltBitWidth % SrcEltBitWidth) == 0 &&
          (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
 
-  if (std::optional<APInt> Bits = extractConstantBits(C)) {
+  if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
     assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
            (Bits->getBitWidth() % DstEltBitWidth) == 0 &&
            "Unexpected constant extension");
@@ -290,19 +294,20 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
       TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
     }
 
+    Type *Ty = C->getType();
     return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
                            SrcEltBitWidth);
   }
 
   return nullptr;
 }
-static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
-                                unsigned SrcEltBitWidth) {
-  return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
+static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
+                                unsigned NumElts, unsigned SrcEltBitWidth) {
+  return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
 }
-static Constant *rebuildZExtCst(const Constant *C, unsigned NumElts,
-                                unsigned SrcEltBitWidth) {
-  return rebuildExtCst(C, false, NumElts, SrcEltBitWidth);
+static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
+                                unsigned NumElts, unsigned SrcEltBitWidth) {
+  return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
 }
 
 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -320,7 +325,7 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
     int Op;
     int NumCstElts;
     int BitWidth;
-    std::function<Constant *(const Constant *, unsigned, unsigned)>
+    std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
         RebuildConstant;
   };
   auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
@@ -335,12 +340,13 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
     assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
            "Unexpected number of operands!");
     if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
+      unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
       for (const FixupEntry &Fixup : Fixups) {
         if (Fixup.Op) {
           // Construct a suitable constant and adjust the MI to use the new
           // constant pool entry.
-          if (Constant *NewCst =
-                  Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
+          if (Constant *NewCst = Fixup.RebuildConstant(
+                  C, NumBits, Fixup.NumCstElts, Fixup.BitWidth)) {
             unsigned NewCPI =
                 CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
             MI.setDesc(TII->get(Fixup.Op));


        


More information about the llvm-commits mailing list