[llvm] b476e6a - [DAGCombiner] add helper function for store merging of loaded values; NFC

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 9 14:26:22 PDT 2020


Author: Sanjay Patel
Date: 2020-07-09T17:20:04-04:00
New Revision: b476e6a642d08aefaf391df026893078cd6ea9b2

URL: https://github.com/llvm/llvm-project/commit/b476e6a642d08aefaf391df026893078cd6ea9b2
DIFF: https://github.com/llvm/llvm-project/commit/b476e6a642d08aefaf391df026893078cd6ea9b2.diff

LOG: [DAGCombiner] add helper function for store merging of loaded values; NFC

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 406af4f18549..309125035d34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -717,7 +717,7 @@ namespace {
     /// store chains that are composed entirely of constant values.
     bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                   unsigned NumConsecutiveStores,
-                                  EVT MemVT, bool AllowVectors, SDNode *Root);
+                                  EVT MemVT, SDNode *Root, bool AllowVectors);
 
     /// This is a helper function for mergeConsecutiveStores. It is used for
     /// store chains that are composed entirely of extracted vector elements.
@@ -727,6 +727,13 @@ namespace {
                                  unsigned NumConsecutiveStores, EVT MemVT,
                                  SDNode *Root);
 
+    /// This is a helper function for mergeConsecutiveStores. It is used for
+    /// store chains that are composed entirely of loaded values.
+    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
+                              unsigned NumConsecutiveStores, EVT MemVT,
+                              SDNode *Root, bool AllowVectors,
+                              bool IsNonTemporalStore, bool IsNonTemporalLoad);
+
     /// Merge consecutive store operations into a wide store.
     /// This optimization uses wide integers or vectors when possible.
     /// \return true if stores were merged.
@@ -16299,7 +16306,7 @@ DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
 
 bool DAGCombiner::tryStoreMergeOfConstants(
     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
-    EVT MemVT, bool AllowVectors, SDNode *RootNode) {
+    EVT MemVT, SDNode *RootNode, bool AllowVectors) {
   LLVMContext &Context = *DAG.getContext();
   const DataLayout &DL = DAG.getDataLayout();
   int64_t ElementSizeBytes = MemVT.getStoreSize();
@@ -16489,6 +16496,261 @@ bool DAGCombiner::tryStoreMergeOfExtracts(
   return MadeChange;
 }
 
+bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
+                                       unsigned NumConsecutiveStores, EVT MemVT,
+                                       SDNode *RootNode, bool AllowVectors,
+                                       bool IsNonTemporalStore,
+                                       bool IsNonTemporalLoad) {
+  LLVMContext &Context = *DAG.getContext();
+  const DataLayout &DL = DAG.getDataLayout();
+  int64_t ElementSizeBytes = MemVT.getStoreSize();
+  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
+  bool MadeChange = false;
+
+  int64_t StartAddress = StoreNodes[0].OffsetFromBase;
+
+  // Look for load nodes which are used by the stored values.
+  SmallVector<MemOpLink, 8> LoadNodes;
+
+  // Find acceptable loads. Loads need to have the same chain (token factor),
+  // must not be zext, volatile, indexed, and they must be consecutive.
+  BaseIndexOffset LdBasePtr;
+
+  for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+    StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
+    SDValue Val = peekThroughBitcasts(St->getValue());
+    LoadSDNode *Ld = cast<LoadSDNode>(Val);
+
+    BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
+    // If this is not the first ptr that we check.
+    int64_t LdOffset = 0;
+    if (LdBasePtr.getBase().getNode()) {
+      // The base ptr must be the same.
+      if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
+        break;
+    } else {
+      // Check that all other base pointers are the same as this one.
+      LdBasePtr = LdPtr;
+    }
+
+    // We found a potential memory operand to merge.
+    LoadNodes.push_back(MemOpLink(Ld, LdOffset));
+  }
+
+  while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
+    // If we have load/store pair instructions and we only have two values,
+    // don't bother merging.
+    Align RequiredAlignment;
+    if (LoadNodes.size() == 2 && TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
+        StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
+      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
+      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
+      break;
+    }
+    LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+    unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+    unsigned FirstStoreAlign = FirstInChain->getAlignment();
+    LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
+    unsigned FirstLoadAlign = FirstLoad->getAlignment();
+
+    // Scan the memory operations on the chain and find the first
+    // non-consecutive load memory address. These variables hold the index in
+    // the store node array.
+
+    unsigned LastConsecutiveLoad = 1;
+
+    // These variables refer to the size and not the index in the array.
+    unsigned LastLegalVectorType = 1;
+    unsigned LastLegalIntegerType = 1;
+    bool isDereferenceable = true;
+    bool DoIntegerTruncate = false;
+    StartAddress = LoadNodes[0].OffsetFromBase;
+    SDValue FirstChain = FirstLoad->getChain();
+    for (unsigned i = 1; i < LoadNodes.size(); ++i) {
+      // All loads must share the same chain.
+      if (LoadNodes[i].MemNode->getChain() != FirstChain)
+        break;
+
+      int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
+      if (CurrAddress - StartAddress != (ElementSizeBytes * i))
+        break;
+      LastConsecutiveLoad = i;
+
+      if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
+        isDereferenceable = false;
+
+      // Find a legal type for the vector store.
+      unsigned Elts = (i + 1) * NumMemElts;
+      EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+
+      // Break early when size is too large to be legal.
+      if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
+        break;
+
+      bool IsFastSt = false;
+      bool IsFastLd = false;
+      if (TLI.isTypeLegal(StoreTy) &&
+          TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+          TLI.allowsMemoryAccess(Context, DL, StoreTy,
+                                 *FirstInChain->getMemOperand(), &IsFastSt) &&
+          IsFastSt &&
+          TLI.allowsMemoryAccess(Context, DL, StoreTy,
+                                 *FirstLoad->getMemOperand(), &IsFastLd) &&
+          IsFastLd) {
+        LastLegalVectorType = i + 1;
+      }
+
+      // Find a legal type for the integer store.
+      unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
+      StoreTy = EVT::getIntegerVT(Context, SizeInBits);
+      if (TLI.isTypeLegal(StoreTy) &&
+          TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+          TLI.allowsMemoryAccess(Context, DL, StoreTy,
+                                 *FirstInChain->getMemOperand(), &IsFastSt) &&
+          IsFastSt &&
+          TLI.allowsMemoryAccess(Context, DL, StoreTy,
+                                 *FirstLoad->getMemOperand(), &IsFastLd) &&
+          IsFastLd) {
+        LastLegalIntegerType = i + 1;
+        DoIntegerTruncate = false;
+        // Or check whether a truncstore and extload is legal.
+      } else if (TLI.getTypeAction(Context, StoreTy) ==
+                 TargetLowering::TypePromoteInteger) {
+        EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
+        if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
+            TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
+            TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
+            TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
+            TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
+            TLI.allowsMemoryAccess(Context, DL, StoreTy,
+                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
+            IsFastSt &&
+            TLI.allowsMemoryAccess(Context, DL, StoreTy,
+                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
+            IsFastLd) {
+          LastLegalIntegerType = i + 1;
+          DoIntegerTruncate = true;
+        }
+      }
+    }
+
+    // Only use vector types if the vector type is larger than the integer
+    // type. If they are the same, use integers.
+    bool UseVectorTy =
+        LastLegalVectorType > LastLegalIntegerType && AllowVectors;
+    unsigned LastLegalType =
+        std::max(LastLegalVectorType, LastLegalIntegerType);
+
+    // We add +1 here because the LastXXX variables refer to location while
+    // the NumElem refers to array/index size.
+    unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
+    NumElem = std::min(LastLegalType, NumElem);
+
+    if (NumElem < 2) {
+      // We know that candidate stores are in order and of correct
+      // shape. While there is no mergeable sequence from the
+      // beginning one may start later in the sequence. The only
+      // reason a merge of size N could have failed where another of
+      // the same size would not have is if the alignment of either
+      // the load or store has improved. Drop as many candidates as we
+      // can here.
+      unsigned NumSkip = 1;
+      while ((NumSkip < LoadNodes.size()) &&
+             (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) &&
+             (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
+        NumSkip++;
+      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
+      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
+      NumConsecutiveStores -= NumSkip;
+      continue;
+    }
+
+    // Check that we can merge these candidates without causing a cycle.
+    if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
+                                                  RootNode)) {
+      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+      NumConsecutiveStores -= NumElem;
+      continue;
+    }
+
+    // Find if it is better to use vectors or integers to load and store
+    // to memory.
+    EVT JointMemOpVT;
+    if (UseVectorTy) {
+      // Find a legal type for the vector store.
+      unsigned Elts = NumElem * NumMemElts;
+      JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+    } else {
+      unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
+      JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
+    }
+
+    SDLoc LoadDL(LoadNodes[0].MemNode);
+    SDLoc StoreDL(StoreNodes[0].MemNode);
+
+    // The merged loads are required to have the same incoming chain, so
+    // using the first's chain is acceptable.
+
+    SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
+    AddToWorklist(NewStoreChain.getNode());
+
+    MachineMemOperand::Flags LdMMOFlags =
+        isDereferenceable ? MachineMemOperand::MODereferenceable
+                          : MachineMemOperand::MONone;
+    if (IsNonTemporalLoad)
+      LdMMOFlags |= MachineMemOperand::MONonTemporal;
+
+    MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
+                                              ? MachineMemOperand::MONonTemporal
+                                              : MachineMemOperand::MONone;
+
+    SDValue NewLoad, NewStore;
+    if (UseVectorTy || !DoIntegerTruncate) {
+      NewLoad = DAG.getLoad(
+          JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
+          FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
+      NewStore = DAG.getStore(
+          NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
+          FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
+    } else { // This must be the truncstore/extload case
+      EVT ExtendedTy =
+          TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
+      NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
+                               FirstLoad->getChain(), FirstLoad->getBasePtr(),
+                               FirstLoad->getPointerInfo(), JointMemOpVT,
+                               FirstLoadAlign, LdMMOFlags);
+      NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
+                                   FirstInChain->getBasePtr(),
+                                   FirstInChain->getPointerInfo(), JointMemOpVT,
+                                   FirstInChain->getAlignment(),
+                                   FirstInChain->getMemOperand()->getFlags());
+    }
+
+    // Transfer chain users from old loads to the new load.
+    for (unsigned i = 0; i < NumElem; ++i) {
+      LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
+      DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
+                                    SDValue(NewLoad.getNode(), 1));
+    }
+
+    // Replace all stores with the new store. Recursively remove corresponding
+    // values if they are no longer used.
+    for (unsigned i = 0; i < NumElem; ++i) {
+      SDValue Val = StoreNodes[i].MemNode->getOperand(1);
+      CombineTo(StoreNodes[i].MemNode, NewStore);
+      if (Val.getNode()->use_empty())
+        recursivelyDeleteUnusedNodes(Val.getNode());
+    }
+
+    MadeChange = true;
+    StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+    LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+    NumConsecutiveStores -= NumElem;
+  }
+  return MadeChange;
+}
+
 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
   if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
     return false;
@@ -16530,15 +16792,11 @@ bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
     return LHS.OffsetFromBase < RHS.OffsetFromBase;
   });
 
-  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
       Attribute::NoImplicitFloat);
-
   bool IsNonTemporalStore = St->isNonTemporal();
   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
-  LLVMContext &Context = *DAG.getContext();
-  const DataLayout &DL = DAG.getDataLayout();
 
   // Store Merge attempts to merge the lowest stores. This generally
   // works out as if successful, as the remaining stores are checked
@@ -16559,7 +16817,7 @@ bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
     if (StoreSrc == StoreSource::Constant) {
       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
-                                             MemVT, AllowVectors, RootNode);
+                                             MemVT, RootNode, AllowVectors);
       continue;
     }
     if (StoreSrc == StoreSource::Extract) {
@@ -16567,257 +16825,12 @@ bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
                                             MemVT, RootNode);
       continue;
     }
-
-    // Below we handle the case of multiple consecutive stores that
-    // come from multiple consecutive loads. We merge them into a single
-    // wide load and a single wide store.
-    assert(StoreSrc == StoreSource::Load && "Expected load source for store");
-    int64_t StartAddress = StoreNodes[0].OffsetFromBase;
-
-    // Look for load nodes which are used by the stored values.
-    SmallVector<MemOpLink, 8> LoadNodes;
-
-    // Find acceptable loads. Loads need to have the same chain (token factor),
-    // must not be zext, volatile, indexed, and they must be consecutive.
-    BaseIndexOffset LdBasePtr;
-
-    for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
-      StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
-      SDValue Val = peekThroughBitcasts(St->getValue());
-      LoadSDNode *Ld = cast<LoadSDNode>(Val);
-
-      BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
-      // If this is not the first ptr that we check.
-      int64_t LdOffset = 0;
-      if (LdBasePtr.getBase().getNode()) {
-        // The base ptr must be the same.
-        if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
-          break;
-      } else {
-        // Check that all other base pointers are the same as this one.
-        LdBasePtr = LdPtr;
-      }
-
-      // We found a potential memory operand to merge.
-      LoadNodes.push_back(MemOpLink(Ld, LdOffset));
-    }
-
-    while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
-      // If we have load/store pair instructions and we only have two values,
-      // don't bother merging.
-      Align RequiredAlignment;
-      if (LoadNodes.size() == 2 &&
-          TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
-          StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
-        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
-        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
-        break;
-      }
-      LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
-      unsigned FirstStoreAS = FirstInChain->getAddressSpace();
-      unsigned FirstStoreAlign = FirstInChain->getAlignment();
-      LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
-      unsigned FirstLoadAlign = FirstLoad->getAlignment();
-
-      // Scan the memory operations on the chain and find the first
-      // non-consecutive load memory address. These variables hold the index in
-      // the store node array.
-
-      unsigned LastConsecutiveLoad = 1;
-
-      // This variable refers to the size and not index in the array.
-      unsigned LastLegalVectorType = 1;
-      unsigned LastLegalIntegerType = 1;
-      bool isDereferenceable = true;
-      bool DoIntegerTruncate = false;
-      StartAddress = LoadNodes[0].OffsetFromBase;
-      SDValue FirstChain = FirstLoad->getChain();
-      for (unsigned i = 1; i < LoadNodes.size(); ++i) {
-        // All loads must share the same chain.
-        if (LoadNodes[i].MemNode->getChain() != FirstChain)
-          break;
-
-        int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
-        if (CurrAddress - StartAddress != (ElementSizeBytes * i))
-          break;
-        LastConsecutiveLoad = i;
-
-        if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
-          isDereferenceable = false;
-
-        // Find a legal type for the vector store.
-        unsigned Elts = (i + 1) * NumMemElts;
-        EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
-
-        // Break early when size is too large to be legal.
-        if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
-          break;
-
-        bool IsFastSt = false;
-        bool IsFastLd = false;
-        if (TLI.isTypeLegal(StoreTy) &&
-            TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
-            TLI.allowsMemoryAccess(Context, DL, StoreTy,
-                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
-            IsFastSt &&
-            TLI.allowsMemoryAccess(Context, DL, StoreTy,
-                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
-            IsFastLd) {
-          LastLegalVectorType = i + 1;
-        }
-
-        // Find a legal type for the integer store.
-        unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
-        StoreTy = EVT::getIntegerVT(Context, SizeInBits);
-        if (TLI.isTypeLegal(StoreTy) &&
-            TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
-            TLI.allowsMemoryAccess(Context, DL, StoreTy,
-                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
-            IsFastSt &&
-            TLI.allowsMemoryAccess(Context, DL, StoreTy,
-                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
-            IsFastLd) {
-          LastLegalIntegerType = i + 1;
-          DoIntegerTruncate = false;
-          // Or check whether a truncstore and extload is legal.
-        } else if (TLI.getTypeAction(Context, StoreTy) ==
-                   TargetLowering::TypePromoteInteger) {
-          EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
-          if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
-              TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) &&
-              TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy,
-                                 StoreTy) &&
-              TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy,
-                                 StoreTy) &&
-              TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
-              TLI.allowsMemoryAccess(Context, DL, StoreTy,
-                                     *FirstInChain->getMemOperand(),
-                                     &IsFastSt) &&
-              IsFastSt &&
-              TLI.allowsMemoryAccess(Context, DL, StoreTy,
-                                     *FirstLoad->getMemOperand(), &IsFastLd) &&
-              IsFastLd) {
-            LastLegalIntegerType = i + 1;
-            DoIntegerTruncate = true;
-          }
-        }
-      }
-
-      // Only use vector types if the vector type is larger than the integer
-      // type. If they are the same, use integers.
-      bool UseVectorTy =
-          LastLegalVectorType > LastLegalIntegerType && AllowVectors;
-      unsigned LastLegalType =
-          std::max(LastLegalVectorType, LastLegalIntegerType);
-
-      // We add +1 here because the LastXXX variables refer to location while
-      // the NumElem refers to array/index size.
-      unsigned NumElem =
-          std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
-      NumElem = std::min(LastLegalType, NumElem);
-
-      if (NumElem < 2) {
-        // We know that candidate stores are in order and of correct
-        // shape. While there is no mergeable sequence from the
-        // beginning one may start later in the sequence. The only
-        // reason a merge of size N could have failed where another of
-        // the same size would not have is if the alignment or either
-        // the load or store has improved. Drop as many candidates as we
-        // can here.
-        unsigned NumSkip = 1;
-        while ((NumSkip < LoadNodes.size()) &&
-               (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) &&
-               (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
-          NumSkip++;
-        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
-        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
-        NumConsecutiveStores -= NumSkip;
-        continue;
-      }
-
-      // Check that we can merge these candidates without causing a cycle.
-      if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
-                                                    RootNode)) {
-        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
-        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
-        NumConsecutiveStores -= NumElem;
-        continue;
-      }
-
-      // Find if it is better to use vectors or integers to load and store
-      // to memory.
-      EVT JointMemOpVT;
-      if (UseVectorTy) {
-        // Find a legal type for the vector store.
-        unsigned Elts = NumElem * NumMemElts;
-        JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
-      } else {
-        unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
-        JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
-      }
-
-      SDLoc LoadDL(LoadNodes[0].MemNode);
-      SDLoc StoreDL(StoreNodes[0].MemNode);
-
-      // The merged loads are required to have the same incoming chain, so
-      // using the first's chain is acceptable.
-
-      SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
-      AddToWorklist(NewStoreChain.getNode());
-
-      MachineMemOperand::Flags LdMMOFlags =
-          isDereferenceable ? MachineMemOperand::MODereferenceable
-                            : MachineMemOperand::MONone;
-      if (IsNonTemporalLoad)
-        LdMMOFlags |= MachineMemOperand::MONonTemporal;
-
-      MachineMemOperand::Flags StMMOFlags =
-          IsNonTemporalStore ? MachineMemOperand::MONonTemporal
-                             : MachineMemOperand::MONone;
-
-      SDValue NewLoad, NewStore;
-      if (UseVectorTy || !DoIntegerTruncate) {
-        NewLoad =
-            DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
-                        FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
-                        FirstLoadAlign, LdMMOFlags);
-        NewStore = DAG.getStore(
-            NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
-            FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
-      } else { // This must be the truncstore/extload case
-        EVT ExtendedTy =
-            TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
-        NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
-                                 FirstLoad->getChain(), FirstLoad->getBasePtr(),
-                                 FirstLoad->getPointerInfo(), JointMemOpVT,
-                                 FirstLoadAlign, LdMMOFlags);
-        NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
-                                     FirstInChain->getBasePtr(),
-                                     FirstInChain->getPointerInfo(),
-                                     JointMemOpVT, FirstInChain->getAlignment(),
-                                     FirstInChain->getMemOperand()->getFlags());
-      }
-
-      // Transfer chain users from old loads to the new load.
-      for (unsigned i = 0; i < NumElem; ++i) {
-        LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
-        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
-                                      SDValue(NewLoad.getNode(), 1));
-      }
-
-      // Replace all stores with the new store. Recursively remove corresponding
-      // values if they are no longer used.
-      for (unsigned i = 0; i < NumElem; ++i) {
-        SDValue Val = StoreNodes[i].MemNode->getOperand(1);
-        CombineTo(StoreNodes[i].MemNode, NewStore);
-        if (Val.getNode()->use_empty())
-          recursivelyDeleteUnusedNodes(Val.getNode());
-      }
-
-      MadeChange = true;
-      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
-      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
-      NumConsecutiveStores -= NumElem;
+    if (StoreSrc == StoreSource::Load) {
+      MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
+                                         MemVT, RootNode, AllowVectors,
+                                         IsNonTemporalStore,
+                                         IsNonTemporalLoad);
+      continue;
     }
   }
   return MadeChange;


        


More information about the llvm-commits mailing list