[llvm] 2665b2a - [X86] Pull out combineConstantPoolLoads helper from combineLoad. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 10:06:39 PDT 2024


Author: Simon Pilgrim
Date: 2024-05-29T18:05:41+01:00
New Revision: 2665b2a6ddb1625799536c45ca15605a6f24c081

URL: https://github.com/llvm/llvm-project/commit/2665b2a6ddb1625799536c45ca15605a6f24c081
DIFF: https://github.com/llvm/llvm-project/commit/2665b2a6ddb1625799536c45ca15605a6f24c081.diff

LOG: [X86] Pull out combineConstantPoolLoads helper from combineLoad. NFC.

The logic is already pretty dense and a future patch will further complicate this.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2d8343ffa1a0b..24340e135b08b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -50823,10 +50823,83 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue combineConstantPoolLoads(SDNode *N, const SDLoc &dl,
+                                        SelectionDAG &DAG,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        const X86Subtarget &Subtarget) {
+  auto *Ld = cast<LoadSDNode>(N);
+  EVT RegVT = Ld->getValueType(0);
+  EVT MemVT = Ld->getMemoryVT();
+  SDValue Ptr = Ld->getBasePtr();
+  SDValue Chain = Ld->getChain();
+  ISD::LoadExtType Ext = Ld->getExtensionType();
+
+  if (Ext != ISD::NON_EXTLOAD || !Subtarget.hasAVX() || !Ld->isSimple())
+    return SDValue();
+
+  if (!(RegVT.is128BitVector() || RegVT.is256BitVector()))
+    return SDValue();
+
+  auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
+                         ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
+    for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
+      if (Undefs[I])
+        continue;
+      if (UserUndefs[I] || Bits[I] != UserBits[I])
+        return false;
+    }
+    return true;
+  };
+
+  // Look through all other loads/broadcasts in the chain for another constant
+  // pool entry.
+  for (SDNode *User : Chain->uses()) {
+    auto *UserLd = dyn_cast<MemSDNode>(User);
+    if (User != N && UserLd &&
+        (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
+         User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
+         ISD::isNormalLoad(User)) &&
+        UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
+        User->getValueSizeInBits(0).getFixedValue() >
+            RegVT.getFixedSizeInBits()) {
+      EVT UserVT = User->getValueType(0);
+      SDValue UserPtr = UserLd->getBasePtr();
+      const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
+      const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
+
+      // See if we are loading a constant that matches in the lower
+      // bits of a longer constant (but from a different constant pool ptr).
+      if (LdC && UserC && UserPtr != Ptr) {
+        unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
+        unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
+        if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
+          APInt Undefs, UserUndefs;
+          SmallVector<APInt> Bits, UserBits;
+          unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
+                                      UserVT.getScalarSizeInBits());
+          if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
+                                            Bits) &&
+              getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
+                                            UserUndefs, UserBits)) {
+            if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
+              SDValue Extract = extractSubVector(
+                  SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
+              Extract = DAG.getBitcast(RegVT, Extract);
+              return DCI.CombineTo(N, Extract, SDValue(User, 1));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
                            TargetLowering::DAGCombinerInfo &DCI,
                            const X86Subtarget &Subtarget) {
-  LoadSDNode *Ld = cast<LoadSDNode>(N);
+  auto *Ld = cast<LoadSDNode>(N);
   EVT RegVT = Ld->getValueType(0);
   EVT MemVT = Ld->getMemoryVT();
   SDLoc dl(Ld);
@@ -50885,7 +50958,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // If we also load/broadcast this to a wider type, then just extract the
+  // If we also broadcast this vector to a wider type, then just extract the
   // lowest subvector.
   if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
       (RegVT.is128BitVector() || RegVT.is256BitVector())) {
@@ -50894,61 +50967,23 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
     for (SDNode *User : Chain->uses()) {
       auto *UserLd = dyn_cast<MemSDNode>(User);
       if (User != N && UserLd &&
-          (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
-           User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
-           ISD::isNormalLoad(User)) &&
-          UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
+          User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
+          UserLd->getChain() == Chain && UserLd->getBasePtr() == Ptr &&
+          UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits() &&
+          !User->hasAnyUseOfValue(1) &&
           User->getValueSizeInBits(0).getFixedValue() >
               RegVT.getFixedSizeInBits()) {
-        if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
-            UserLd->getBasePtr() == Ptr &&
-            UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
-          SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
-                                             RegVT.getSizeInBits());
-          Extract = DAG.getBitcast(RegVT, Extract);
-          return DCI.CombineTo(N, Extract, SDValue(User, 1));
-        }
-        auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
-                               ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
-          for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
-            if (Undefs[I])
-              continue;
-            if (UserUndefs[I] || Bits[I] != UserBits[I])
-              return false;
-          }
-          return true;
-        };
-        // See if we are loading a constant that matches in the lower
-        // bits of a longer constant (but from a different constant pool ptr).
-        EVT UserVT = User->getValueType(0);
-        SDValue UserPtr = UserLd->getBasePtr();
-        const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
-        const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
-        if (LdC && UserC && UserPtr != Ptr) {
-          unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
-          unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
-          if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
-            APInt Undefs, UserUndefs;
-            SmallVector<APInt> Bits, UserBits;
-            unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
-                                        UserVT.getScalarSizeInBits());
-            if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
-                                              Bits) &&
-                getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
-                                              UserUndefs, UserBits)) {
-              if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
-                SDValue Extract = extractSubVector(
-                    SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
-                Extract = DAG.getBitcast(RegVT, Extract);
-                return DCI.CombineTo(N, Extract, SDValue(User, 1));
-              }
-            }
-          }
-        }
+        SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
+                                           RegVT.getSizeInBits());
+        Extract = DAG.getBitcast(RegVT, Extract);
+        return DCI.CombineTo(N, Extract, SDValue(User, 1));
       }
     }
   }
 
+  if (SDValue V = combineConstantPoolLoads(Ld, dl, DAG, DCI, Subtarget))
+    return V;
+
   // Cast ptr32 and ptr64 pointers to the default address space before a load.
   unsigned AddrSpace = Ld->getAddressSpace();
   if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||


        


More information about the llvm-commits mailing list