[llvm] r316417 - [CodeGen][ExpandMemcmp][NFC] Allow memcmp to expand to vector loads (1)

Clement Courbet via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 24 01:05:07 PDT 2017


Author: courbet
Date: Tue Oct 24 01:05:07 2017
New Revision: 316417

URL: http://llvm.org/viewvc/llvm-project?rev=316417&view=rev
Log:
[CodeGen][ExpandMemcmp][NFC] Allow memcmp to expand to vector loads (1)

Refactor ExpandMemcmp:

 - Stop duplicating the logic for computation of the sequence of loads to
   generate (this was done in three different places); this is now done
   only once in MemCmpExpansion::MemCmpExpansion().

 - Add a FIXME to expose a bug with the computation of the number of loads
   when not all sizes are loadable. For example, on X86-32 + SSE, possible
   loads are {16,4,2,1} bytes. The current code considers that all loads
   starting at MaxLoadSize are possible. This is not an issue right now as
   vector loads are not enabled, so I'm not fixing the issue here to keep
   the change as small as possible. I'm going to address this in a
   subsequent revision, where I enable vector loads.

See https://bugs.llvm.org/show_bug.cgi?id=34887

Differential Revision: https://reviews.llvm.org/D38498

Modified:
    llvm/trunk/lib/CodeGen/CodeGenPrepare.cpp

Modified: llvm/trunk/lib/CodeGen/CodeGenPrepare.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/CodeGenPrepare.cpp?rev=316417&r1=316416&r2=316417&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/CodeGenPrepare.cpp (original)
+++ llvm/trunk/lib/CodeGen/CodeGenPrepare.cpp Tue Oct 24 01:05:07 2017
@@ -1710,43 +1710,61 @@ class MemCmpExpansion {
     ResultBlock() = default;
   };
 
-  CallInst *CI;
+  CallInst *const CI;
   ResultBlock ResBlock;
+  const unsigned Size;
   unsigned MaxLoadSize;
-  unsigned NumBlocks;
-  unsigned NumBlocksNonOneByte;
-  unsigned NumLoadsPerBlock;
+  unsigned NumLoadsNonOneByte;
+  const unsigned NumLoadsPerBlock;
   std::vector<BasicBlock *> LoadCmpBlocks;
   BasicBlock *EndBlock;
   PHINode *PhiRes;
-  bool IsUsedForZeroCmp;
+  const bool IsUsedForZeroCmp;
   const DataLayout &DL;
   IRBuilder<> Builder;
+  // Represents the decomposition in blocks of the expansion. For example,
+  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
+  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
+  // TODO(courbet): Involve the target more in this computation. On X86, 7
+  // bytes can be done more efficiently with two overlapping 4-byte loads than
+  // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
+  struct LoadEntry {
+    LoadEntry(unsigned LoadSize, uint64_t Offset)
+        : LoadSize(LoadSize), Offset(Offset) {
+      assert(Offset % LoadSize == 0 && "invalid load entry");
+    }
+
+    uint64_t getGEPIndex() const { return Offset / LoadSize; }
+
+    // The size of the load for this block, in bytes.
+    const unsigned LoadSize;
+    // The offset of this load WRT the base pointer, in bytes.
+    const uint64_t Offset;
+  };
+  SmallVector<LoadEntry, 8> LoadSequence;
 
-  unsigned calculateNumBlocks(unsigned Size);
   void createLoadCmpBlocks();
   void createResultBlock();
   void setupResultBlockPHINodes();
   void setupEndBlockPHINodes();
-  void emitLoadCompareBlock(unsigned Index, unsigned LoadSize,
-                            unsigned GEPIndex);
-  Value *getCompareLoadPairs(unsigned Index, unsigned Size,
-                             unsigned &NumBytesProcessed);
-  void emitLoadCompareBlockMultipleLoads(unsigned Index, unsigned Size,
-                                         unsigned &NumBytesProcessed);
-  void emitLoadCompareByteBlock(unsigned Index, unsigned GEPIndex);
+  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
+  void emitLoadCompareBlock(unsigned BlockIndex);
+  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
+                                         unsigned &LoadIndex);
+  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
   void emitMemCmpResultBlock();
-  Value *getMemCmpExpansionZeroCase(unsigned Size);
-  Value *getMemCmpEqZeroOneBlock(unsigned Size);
-  Value *getMemCmpOneBlock(unsigned Size);
-  unsigned getLoadSize(unsigned Size);
-  unsigned getNumLoads(unsigned Size);
+  Value *getMemCmpExpansionZeroCase();
+  Value *getMemCmpEqZeroOneBlock();
+  Value *getMemCmpOneBlock();
 
-public:
+ public:
   MemCmpExpansion(CallInst *CI, uint64_t Size, unsigned MaxLoadSize,
                   unsigned NumLoadsPerBlock, const DataLayout &DL);
 
-  Value *getMemCmpExpansion(uint64_t Size);
+  unsigned getNumBlocks();
+  unsigned getNumLoads() const { return LoadSequence.size(); }
+
+  Value *getMemCmpExpansion();
 };
 
 } // end anonymous namespace
@@ -1759,43 +1777,56 @@ public:
 // return from.
 // 3. ResultBlock, block to branch to for early exit when a
 // LoadCmpBlock finds a difference.
-MemCmpExpansion::MemCmpExpansion(CallInst *CI, uint64_t Size,
-                                 unsigned MaxLoadSize, unsigned LoadsPerBlock,
+MemCmpExpansion::MemCmpExpansion(CallInst *const CI, uint64_t Size,
+                                 const unsigned MaxLoadSize,
+                                 const unsigned LoadsPerBlock,
                                  const DataLayout &TheDataLayout)
-    : CI(CI), MaxLoadSize(MaxLoadSize), NumLoadsPerBlock(LoadsPerBlock),
-      DL(TheDataLayout), Builder(CI) {
-  // A memcmp with zero-comparison with only one block of load and compare does
-  // not need to set up any extra blocks. This case could be handled in the DAG,
-  // but since we have all of the machinery to flexibly expand any memcpy here,
-  // we choose to handle this case too to avoid fragmented lowering.
-  IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
-  NumBlocks = calculateNumBlocks(Size);
-  if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || NumBlocks != 1) {
-    BasicBlock *StartBlock = CI->getParent();
-    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
-    setupEndBlockPHINodes();
-    createResultBlock();
-
-    // If return value of memcmp is not used in a zero equality, we need to
-    // calculate which source was larger. The calculation requires the
-    // two loaded source values of each load compare block.
-    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
-    if (!IsUsedForZeroCmp)
-      setupResultBlockPHINodes();
-
-    // Create the number of required load compare basic blocks.
-    createLoadCmpBlocks();
-
-    // Update the terminator added by splitBasicBlock to branch to the first
-    // LoadCmpBlock.
-    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
+    : CI(CI),
+      Size(Size),
+      MaxLoadSize(MaxLoadSize),
+      NumLoadsNonOneByte(0),
+      NumLoadsPerBlock(LoadsPerBlock),
+      IsUsedForZeroCmp(isOnlyUsedInZeroEqualityComparison(CI)),
+      DL(TheDataLayout),
+      Builder(CI) {
+  // Scale the max size down if the target can load more bytes than we need.
+  while (this->MaxLoadSize > Size) {
+    this->MaxLoadSize /= 2;
   }
+  // Compute the decomposition.
+  unsigned LoadSize = this->MaxLoadSize;
+  assert(Size > 0 && "zero blocks");
+  uint64_t Offset = 0;
+  while (Size) {
+    assert(LoadSize > 0 && "zero load size");
+    const uint64_t NumLoadsForThisSize = Size / LoadSize;
+    if (NumLoadsForThisSize > 0) {
+      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
+        LoadSequence.push_back({LoadSize, Offset});
+        Offset += LoadSize;
+      }
+      if (LoadSize > 1) {
+        ++NumLoadsNonOneByte;
+      }
+      Size = Size % LoadSize;
+    }
+    // FIXME: This can result in a non-native load size (e.g. X86-32+SSE can
+    // load 16 and 4 but not 8), which throws the load count off (e.g. in the
+    // aforementioned case, 16 bytes will count for 2 loads but will generate
+    // 4).
+    LoadSize /= 2;
+  }
+}
 
-  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+unsigned MemCmpExpansion::getNumBlocks() {
+  if (IsUsedForZeroCmp)
+    return getNumLoads() / NumLoadsPerBlock +
+           (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0);
+  return getNumLoads();
 }
 
 void MemCmpExpansion::createLoadCmpBlocks() {
-  for (unsigned i = 0; i < NumBlocks; i++) {
+  for (unsigned i = 0; i < getNumBlocks(); i++) {
     BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                         EndBlock->getParent(), EndBlock);
     LoadCmpBlocks.push_back(BB);
@@ -1811,12 +1842,12 @@ void MemCmpExpansion::createResultBlock(
 // It loads 1 byte from each source of the memcmp parameters with the given
 // GEPIndex. It then subtracts the two loaded values and adds this result to the
 // final phi node for selecting the memcmp result.
-void MemCmpExpansion::emitLoadCompareByteBlock(unsigned Index,
+void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                                unsigned GEPIndex) {
   Value *Source1 = CI->getArgOperand(0);
   Value *Source2 = CI->getArgOperand(1);
 
-  Builder.SetInsertPoint(LoadCmpBlocks[Index]);
+  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
   Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
   // Cast source to LoadSizeType*.
   if (Source1->getType() != LoadSizeType)
@@ -1839,15 +1870,15 @@ void MemCmpExpansion::emitLoadCompareByt
   LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
   Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
 
-  PhiRes->addIncoming(Diff, LoadCmpBlocks[Index]);
+  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
 
-  if (Index < (LoadCmpBlocks.size() - 1)) {
+  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
     // Early exit branch if difference found to EndBlock. Otherwise, continue to
     // next LoadCmpBlock,
     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                     ConstantInt::get(Diff->getType(), 0));
     BranchInst *CmpBr =
-        BranchInst::Create(EndBlock, LoadCmpBlocks[Index + 1], Cmp);
+        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
     Builder.Insert(CmpBr);
   } else {
     // The last block has an unconditional branch to EndBlock.
@@ -1856,42 +1887,37 @@ void MemCmpExpansion::emitLoadCompareByt
   }
 }
 
-unsigned MemCmpExpansion::getNumLoads(unsigned Size) {
-  return (Size / MaxLoadSize) + countPopulation(Size % MaxLoadSize);
-}
-
-unsigned MemCmpExpansion::getLoadSize(unsigned Size) {
-  return MinAlign(PowerOf2Floor(Size), MaxLoadSize);
-}
-
 /// Generate an equality comparison for one or more pairs of loaded values.
 /// This is used in the case where the memcmp() call is compared equal or not
 /// equal to zero.
-Value *MemCmpExpansion::getCompareLoadPairs(unsigned Index, unsigned Size,
-                                            unsigned &NumBytesProcessed) {
+Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
+                                            unsigned &LoadIndex) {
+  assert(LoadIndex < getNumLoads() &&
+         "getCompareLoadPairs() called with no remaining loads");
   std::vector<Value *> XorList, OrList;
   Value *Diff;
 
-  unsigned RemainingBytes = Size - NumBytesProcessed;
-  unsigned NumLoadsRemaining = getNumLoads(RemainingBytes);
-  unsigned NumLoads = std::min(NumLoadsRemaining, NumLoadsPerBlock);
+  const unsigned NumLoads =
+      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock);
 
   // For a single-block expansion, start inserting before the memcmp call.
   if (LoadCmpBlocks.empty())
     Builder.SetInsertPoint(CI);
   else
-    Builder.SetInsertPoint(LoadCmpBlocks[Index]);
+    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
 
   Value *Cmp = nullptr;
-  for (unsigned i = 0; i < NumLoads; ++i) {
-    unsigned LoadSize = getLoadSize(RemainingBytes);
-    unsigned GEPIndex = NumBytesProcessed / LoadSize;
-    NumBytesProcessed += LoadSize;
-    RemainingBytes -= LoadSize;
-
-    Type *LoadSizeType = IntegerType::get(CI->getContext(), LoadSize * 8);
-    Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
-    assert(LoadSize <= MaxLoadSize && "Unexpected load type");
+  // If we have multiple loads per block, we need to generate a composite
+  // comparison using xor+or. The type for the combinations is the largest load
+  // type.
+  IntegerType *const MaxLoadType =
+      NumLoads == 1 ? nullptr
+                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
+  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
+    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
+
+    IntegerType *LoadSizeType =
+        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
 
     Value *Source1 = CI->getArgOperand(0);
     Value *Source2 = CI->getArgOperand(1);
@@ -1902,12 +1928,14 @@ Value *MemCmpExpansion::getCompareLoadPa
     if (Source2->getType() != LoadSizeType)
       Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
 
-    // Get the base address using the GEPIndex.
-    if (GEPIndex != 0) {
-      Source1 = Builder.CreateGEP(LoadSizeType, Source1,
-                                  ConstantInt::get(LoadSizeType, GEPIndex));
-      Source2 = Builder.CreateGEP(LoadSizeType, Source2,
-                                  ConstantInt::get(LoadSizeType, GEPIndex));
+    // Get the base address using a GEP.
+    if (CurLoadEntry.Offset != 0) {
+      Source1 = Builder.CreateGEP(
+          LoadSizeType, Source1,
+          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
+      Source2 = Builder.CreateGEP(
+          LoadSizeType, Source2,
+          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
     }
 
     // Get a constant or load a value for each source address.
@@ -1964,13 +1992,13 @@ Value *MemCmpExpansion::getCompareLoadPa
   return Cmp;
 }
 
-void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(
-    unsigned Index, unsigned Size, unsigned &NumBytesProcessed) {
-  Value *Cmp = getCompareLoadPairs(Index, Size, NumBytesProcessed);
+void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
+                                                        unsigned &LoadIndex) {
+  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
 
-  BasicBlock *NextBB = (Index == (LoadCmpBlocks.size() - 1))
+  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                            ? EndBlock
-                           : LoadCmpBlocks[Index + 1];
+                           : LoadCmpBlocks[BlockIndex + 1];
   // Early exit branch if difference found to ResultBlock. Otherwise,
   // continue to next LoadCmpBlock or EndBlock.
   BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
@@ -1979,9 +2007,9 @@ void MemCmpExpansion::emitLoadCompareBlo
   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
   // since early exit to ResultBlock was not taken (no difference was found in
   // any of the bytes).
-  if (Index == LoadCmpBlocks.size() - 1) {
+  if (BlockIndex == LoadCmpBlocks.size() - 1) {
     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
-    PhiRes->addIncoming(Zero, LoadCmpBlocks[Index]);
+    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
   }
 }
 
@@ -1994,33 +2022,39 @@ void MemCmpExpansion::emitLoadCompareBlo
 // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
 // a special case through emitLoadCompareByteBlock. The special handling can
 // simply subtract the loaded values and add it to the result phi node.
-void MemCmpExpansion::emitLoadCompareBlock(unsigned Index, unsigned LoadSize,
-                                           unsigned GEPIndex) {
-  if (LoadSize == 1) {
-    MemCmpExpansion::emitLoadCompareByteBlock(Index, GEPIndex);
+void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
+  // There is one load per block in this case, BlockIndex == LoadIndex.
+  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
+
+  if (CurLoadEntry.LoadSize == 1) {
+    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
+                                              CurLoadEntry.getGEPIndex());
     return;
   }
 
-  Type *LoadSizeType = IntegerType::get(CI->getContext(), LoadSize * 8);
+  Type *LoadSizeType =
+      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
-  assert(LoadSize <= MaxLoadSize && "Unexpected load type");
+  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
 
   Value *Source1 = CI->getArgOperand(0);
   Value *Source2 = CI->getArgOperand(1);
 
-  Builder.SetInsertPoint(LoadCmpBlocks[Index]);
+  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
   // Cast source to LoadSizeType*.
   if (Source1->getType() != LoadSizeType)
     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
   if (Source2->getType() != LoadSizeType)
     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
 
-  // Get the base address using the GEPIndex.
-  if (GEPIndex != 0) {
-    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
-                                ConstantInt::get(LoadSizeType, GEPIndex));
-    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
-                                ConstantInt::get(LoadSizeType, GEPIndex));
+  // Get the base address using a GEP.
+  if (CurLoadEntry.Offset != 0) {
+    Source1 = Builder.CreateGEP(
+        LoadSizeType, Source1,
+        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
+    Source2 = Builder.CreateGEP(
+        LoadSizeType, Source2,
+        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
   }
 
   // Load LoadSizeType from the base address.
@@ -2042,14 +2076,14 @@ void MemCmpExpansion::emitLoadCompareBlo
   // Add the loaded values to the phi nodes for calculating memcmp result only
   // if result is not used in a zero equality.
   if (!IsUsedForZeroCmp) {
-    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[Index]);
-    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[Index]);
+    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
+    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
   }
 
   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
-  BasicBlock *NextBB = (Index == (LoadCmpBlocks.size() - 1))
+  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                            ? EndBlock
-                           : LoadCmpBlocks[Index + 1];
+                           : LoadCmpBlocks[BlockIndex + 1];
   // Early exit branch if difference found to ResultBlock. Otherwise, continue
   // to next LoadCmpBlock or EndBlock.
   BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
@@ -2058,9 +2092,9 @@ void MemCmpExpansion::emitLoadCompareBlo
   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
   // since early exit to ResultBlock was not taken (no difference was found in
   // any of the bytes).
-  if (Index == LoadCmpBlocks.size() - 1) {
+  if (BlockIndex == LoadCmpBlocks.size() - 1) {
     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
-    PhiRes->addIncoming(Zero, LoadCmpBlocks[Index]);
+    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
   }
 }
 
@@ -2094,34 +2128,14 @@ void MemCmpExpansion::emitMemCmpResultBl
   PhiRes->addIncoming(Res, ResBlock.BB);
 }
 
-unsigned MemCmpExpansion::calculateNumBlocks(unsigned Size) {
-  unsigned NumBlocks = 0;
-  bool HaveOneByteLoad = false;
-  unsigned RemainingSize = Size;
-  unsigned LoadSize = MaxLoadSize;
-  while (RemainingSize) {
-    if (LoadSize == 1)
-      HaveOneByteLoad = true;
-    NumBlocks += RemainingSize / LoadSize;
-    RemainingSize = RemainingSize % LoadSize;
-    LoadSize = LoadSize / 2;
-  }
-  NumBlocksNonOneByte = HaveOneByteLoad ? (NumBlocks - 1) : NumBlocks;
-
-  if (IsUsedForZeroCmp)
-    NumBlocks = NumBlocks / NumLoadsPerBlock +
-                (NumBlocks % NumLoadsPerBlock != 0 ? 1 : 0);
-
-  return NumBlocks;
-}
-
 void MemCmpExpansion::setupResultBlockPHINodes() {
   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
   Builder.SetInsertPoint(ResBlock.BB);
+  // Note: this assumes one load per block.
   ResBlock.PhiSrc1 =
-      Builder.CreatePHI(MaxLoadType, NumBlocksNonOneByte, "phi.src1");
+      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
   ResBlock.PhiSrc2 =
-      Builder.CreatePHI(MaxLoadType, NumBlocksNonOneByte, "phi.src2");
+      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
 }
 
 void MemCmpExpansion::setupEndBlockPHINodes() {
@@ -2129,12 +2143,13 @@ void MemCmpExpansion::setupEndBlockPHINo
   PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
 }
 
-Value *MemCmpExpansion::getMemCmpExpansionZeroCase(unsigned Size) {
-  unsigned NumBytesProcessed = 0;
+Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
+  unsigned LoadIndex = 0;
   // This loop populates each of the LoadCmpBlocks with the IR sequence to
   // handle multiple loads per block.
-  for (unsigned i = 0; i < NumBlocks; ++i)
-    emitLoadCompareBlockMultipleLoads(i, Size, NumBytesProcessed);
+  for (unsigned I = 0; I < getNumBlocks(); ++I) {
+    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
+  }
 
   emitMemCmpResultBlock();
   return PhiRes;
@@ -2143,15 +2158,16 @@ Value *MemCmpExpansion::getMemCmpExpansi
 /// A memcmp expansion that compares equality with 0 and only has one block of
 /// load and compare can bypass the compare, branch, and phi IR that is required
 /// in the general case.
-Value *MemCmpExpansion::getMemCmpEqZeroOneBlock(unsigned Size) {
-  unsigned NumBytesProcessed = 0;
-  Value *Cmp = getCompareLoadPairs(0, Size, NumBytesProcessed);
+Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
+  unsigned LoadIndex = 0;
+  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
+  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
   return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
 }
 
 /// A memcmp expansion that only has one block of load and compare can bypass
 /// the compare, branch, and phi IR that is required in the general case.
-Value *MemCmpExpansion::getMemCmpOneBlock(unsigned Size) {
+Value *MemCmpExpansion::getMemCmpOneBlock() {
   assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block");
 
   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
@@ -2198,37 +2214,42 @@ Value *MemCmpExpansion::getMemCmpOneBloc
 
 // This function expands the memcmp call into an inline expansion and returns
 // the memcmp result.
-Value *MemCmpExpansion::getMemCmpExpansion(uint64_t Size) {
+Value *MemCmpExpansion::getMemCmpExpansion() {
+  // A memcmp with zero-comparison with only one block of load and compare does
+  // not need to set up any extra blocks. This case could be handled in the DAG,
+  // but since we have all of the machinery to flexibly expand any memcmp here,
+  // we choose to handle this case too to avoid fragmented lowering.
+  if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) {
+    BasicBlock *StartBlock = CI->getParent();
+    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
+    setupEndBlockPHINodes();
+    createResultBlock();
+
+    // If return value of memcmp is not used in a zero equality, we need to
+    // calculate which source was larger. The calculation requires the
+    // two loaded source values of each load compare block.
+    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
+    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
+
+    // Create the number of required load compare basic blocks.
+    createLoadCmpBlocks();
+
+    // Update the terminator added by splitBasicBlock to branch to the first
+    // LoadCmpBlock.
+    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
+  }
+
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
   if (IsUsedForZeroCmp)
-    return NumBlocks == 1 ? getMemCmpEqZeroOneBlock(Size) :
-                            getMemCmpExpansionZeroCase(Size);
+    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
+                               : getMemCmpExpansionZeroCase();
 
   // TODO: Handle more than one load pair per block in getMemCmpOneBlock().
-  if (NumBlocks == 1 && NumLoadsPerBlock == 1)
-    return getMemCmpOneBlock(Size);
+  if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock();
 
-  // This loop calls emitLoadCompareBlock for comparing Size bytes of the two
-  // memcmp sources. It starts with loading using the maximum load size set by
-  // the target. It processes any remaining bytes using a load size which is the
-  // next smallest power of 2.
-  unsigned LoadSize = MaxLoadSize;
-  unsigned NumBytesToBeProcessed = Size;
-  unsigned Index = 0;
-  while (NumBytesToBeProcessed) {
-    // Calculate how many blocks we can create with the current load size.
-    unsigned NumBlocks = NumBytesToBeProcessed / LoadSize;
-    unsigned GEPIndex = (Size - NumBytesToBeProcessed) / LoadSize;
-    NumBytesToBeProcessed = NumBytesToBeProcessed % LoadSize;
-
-    // For each NumBlocks, populate the instruction sequence for loading and
-    // comparing LoadSize bytes.
-    while (NumBlocks--) {
-      emitLoadCompareBlock(Index, LoadSize, GEPIndex);
-      Index++;
-      GEPIndex++;
-    }
-    // Get the next LoadSize to use.
-    LoadSize = LoadSize / 2;
+  for (unsigned I = 0; I < getNumBlocks(); ++I) {
+    emitLoadCompareBlock(I);
   }
 
   emitMemCmpResultBlock();
@@ -2312,12 +2333,6 @@ static bool expandMemCmp(CallInst *CI, c
                          const TargetLowering *TLI, const DataLayout *DL) {
   NumMemCmpCalls++;
 
-  // TTI call to check if target would like to expand memcmp. Also, get the
-  // MaxLoadSize.
-  unsigned MaxLoadSize;
-  if (!TTI->enableMemCmpExpansion(MaxLoadSize))
-    return false;
-
   // Early exit from expansion if -Oz.
   if (CI->getFunction()->optForMinSize())
     return false;
@@ -2328,36 +2343,26 @@ static bool expandMemCmp(CallInst *CI, c
     NumMemCmpNotConstant++;
     return false;
   }
+  const uint64_t SizeVal = SizeCast->getZExtValue();
 
-  // Scale the max size down if the target can load more bytes than we need.
-  uint64_t SizeVal = SizeCast->getZExtValue();
-  if (MaxLoadSize > SizeVal)
-    MaxLoadSize = 1 << SizeCast->getValue().logBase2();
-
-  // Calculate how many load pairs are needed for the constant size.
-  unsigned NumLoads = 0;
-  unsigned RemainingSize = SizeVal;
-  unsigned LoadSize = MaxLoadSize;
-  while (RemainingSize) {
-    NumLoads += RemainingSize / LoadSize;
-    RemainingSize = RemainingSize % LoadSize;
-    LoadSize = LoadSize / 2;
-  }
+  // TTI call to check if target would like to expand memcmp. Also, get the
+  // max LoadSize.
+  unsigned MaxLoadSize;
+  if (!TTI->enableMemCmpExpansion(MaxLoadSize)) return false;
+
+  MemCmpExpansion Expansion(CI, SizeVal, MaxLoadSize, MemCmpNumLoadsPerBlock,
+                            *DL);
 
   // Don't expand if this will require more loads than desired by the target.
-  if (NumLoads > TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize())) {
+  if (Expansion.getNumLoads() >
+      TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize())) {
     NumMemCmpGreaterThanMax++;
     return false;
   }
 
   NumMemCmpInlined++;
 
-  // MemCmpHelper object creates and sets up basic blocks required for
-  // expanding memcmp with size SizeVal.
-  unsigned NumLoadsPerBlock = MemCmpNumLoadsPerBlock;
-  MemCmpExpansion MemCmpHelper(CI, SizeVal, MaxLoadSize, NumLoadsPerBlock, *DL);
-
-  Value *Res = MemCmpHelper.getMemCmpExpansion(SizeVal);
+  Value *Res = Expansion.getMemCmpExpansion();
 
   // Replace call with result of expansion and erase call.
   CI->replaceAllUsesWith(Res);




More information about the llvm-commits mailing list