[llvm] [LowerMemIntrinsics] Lower llvm.memmove to wide memory accesses (PR #100122)

Fabian Ritter via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 24 01:05:09 PDT 2024


================
@@ -378,85 +382,229 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr,
   BasicBlock *OrigBB = InsertBefore->getParent();
   Function *F = OrigBB->getParent();
   const DataLayout &DL = F->getDataLayout();
-  // TODO: Use different element type if possible?
-  Type *EltTy = Type::getInt8Ty(F->getContext());
+  LLVMContext &Ctx = OrigBB->getContext();
+  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
+  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+
+  Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
+      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value());
+  unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
+  Type *Int8Type = Type::getInt8Ty(Ctx);
+  bool LoopOpIsInt8 = LoopOpType == Int8Type;
+
+  // If the memory accesses are wider than one byte, residual loops with
+  // i8-accesses are required to move remaining bytes.
+  bool RequiresResidual = !LoopOpIsInt8;
+
+  // Calculate the loop trip count and remaining bytes to copy after the loop.
+  IntegerType *ILengthType = dyn_cast<IntegerType>(TypeOfCopyLen);
+  assert(ILengthType &&
+         "expected size argument to memcpy to be an integer type!");
+  ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize);
+  ConstantInt *Zero = ConstantInt::get(ILengthType, 0);
+  ConstantInt *One = ConstantInt::get(ILengthType, 1);
+
+  IRBuilder<> PLBuilder(InsertBefore);
+
+  Value *RuntimeLoopCount = CopyLen;
+  Value *RuntimeLoopRemainder = nullptr;
+  Value *RuntimeBytesCopiedMainLoop = CopyLen;
+  Value *SkipResidualCondition = nullptr;
+  if (RequiresResidual) {
+    RuntimeLoopCount =
+        getRuntimeLoopCount(DL, PLBuilder, CopyLen, CILoopOpSize, LoopOpSize);
+    RuntimeLoopRemainder = getRuntimeLoopRemainder(DL, PLBuilder, CopyLen,
+                                                   CILoopOpSize, LoopOpSize);
+    RuntimeBytesCopiedMainLoop =
+        PLBuilder.CreateSub(CopyLen, RuntimeLoopRemainder);
+    SkipResidualCondition =
+        PLBuilder.CreateICmpEQ(RuntimeLoopRemainder, Zero, "skip_residual");
+  }
+  Value *SkipMainCondition =
+      PLBuilder.CreateICmpEQ(RuntimeLoopCount, Zero, "skip_main");
 
   // Create the a comparison of src and dst, based on which we jump to either
   // the forward-copy part of the function (if src >= dst) or the backwards-copy
   // part (if src < dst).
   // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
   // structure. Its block terminators (unconditional branches) are replaced by
   // the appropriate conditional branches when the loop is built.
-  ICmpInst *PtrCompare = new ICmpInst(InsertBefore->getIterator(), ICmpInst::ICMP_ULT,
-                                      SrcAddr, DstAddr, "compare_src_dst");
+  Value *PtrCompare =
+      PLBuilder.CreateICmpULT(SrcAddr, DstAddr, "compare_src_dst");
   Instruction *ThenTerm, *ElseTerm;
-  SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore->getIterator(), &ThenTerm,
-                                &ElseTerm);
-
-  // Each part of the function consists of two blocks:
-  //   copy_backwards:        used to skip the loop when n == 0
-  //   copy_backwards_loop:   the actual backwards loop BB
-  //   copy_forward:          used to skip the loop when n == 0
-  //   copy_forward_loop:     the actual forward loop BB
+  SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore->getIterator(),
+                                &ThenTerm, &ElseTerm);
+
+  // If the LoopOpSize is greater than 1, each part of the function consists of
+  // four blocks:
+  //   memmove_copy_backwards:
+  //       skip the residual loop when 0 iterations are required
+  //   memmove_bwd_residual_loop:
+  //       copy the last few bytes individually so that the remaining length is
+  //       a multiple of the LoopOpSize
+  //   memmove_bwd_middle: skip the main loop when 0 iterations are required
+  //   memmove_bwd_main_loop: the actual backwards loop BB with wide accesses
+  //   memmove_copy_forward: skip the main loop when 0 iterations are required
+  //   memmove_fwd_main_loop: the actual forward loop BB with wide accesses
+  //   memmove_fwd_middle: skip the residual loop when 0 iterations are required
+  //   memmove_fwd_residual_loop: copy the last few bytes individually
+  //
+  // The main and residual loop are switched between copying forward and
+  // backward so that the residual loop always operates on the end of the moved
+  // range. This is based on the assumption that buffers whose start is aligned
+  // with the LoopOpSize are more common than buffers whose end is.
+  //
+  // If the LoopOpSize is 1, each part of the function consists of two blocks:
+  //   memmove_copy_backwards: skip the loop when 0 iterations are required
+  //   memmove_bwd_main_loop: the actual backwards loop BB
+  //   memmove_copy_forward: skip the loop when 0 iterations are required
+  //   memmove_fwd_main_loop: the actual forward loop BB
   BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
-  CopyBackwardsBB->setName("copy_backwards");
+  CopyBackwardsBB->setName("memmove_copy_backwards");
   BasicBlock *CopyForwardBB = ElseTerm->getParent();
-  CopyForwardBB->setName("copy_forward");
+  CopyForwardBB->setName("memmove_copy_forward");
   BasicBlock *ExitBB = InsertBefore->getParent();
   ExitBB->setName("memmove_done");
 
-  unsigned PartSize = DL.getTypeStoreSize(EltTy);
-  Align PartSrcAlign(commonAlignment(SrcAlign, PartSize));
-  Align PartDstAlign(commonAlignment(DstAlign, PartSize));
-
-  // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
-  // between both backwards and forward copy clauses.
-  ICmpInst *CompareN =
-      new ICmpInst(OrigBB->getTerminator()->getIterator(), ICmpInst::ICMP_EQ, CopyLen,
-                   ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
+  Align PartSrcAlign(commonAlignment(SrcAlign, LoopOpSize));
+  Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));
 
   // Copying backwards.
-  BasicBlock *LoopBB =
-    BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB);
-  IRBuilder<> LoopBuilder(LoopBB);
+  {
+    BasicBlock *MainLoopBB = BasicBlock::Create(
+        F->getContext(), "memmove_bwd_main_loop", F, CopyForwardBB);
+
+    // The predecessor of the memmove_bwd_main_loop. Updated in the
+    // following if a residual loop is emitted first.
+    BasicBlock *PredBB = CopyBackwardsBB;
+
+    if (RequiresResidual) {
+      // backwards residual loop
+      BasicBlock *ResidualLoopBB = BasicBlock::Create(
+          F->getContext(), "memmove_bwd_residual_loop", F, MainLoopBB);
+      IRBuilder<> ResidualLoopBuilder(ResidualLoopBB);
+      PHINode *ResidualLoopPhi = ResidualLoopBuilder.CreatePHI(ILengthType, 0);
+      Value *ResidualIndex = ResidualLoopBuilder.CreateSub(
+          ResidualLoopPhi, One, "bwd_residual_index");
+      Value *LoadGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr,
+                                                             ResidualIndex);
+      Value *Element = ResidualLoopBuilder.CreateLoad(Int8Type, LoadGEP,
+                                                      SrcIsVolatile, "element");
+      Value *StoreGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr,
+                                                              ResidualIndex);
+      ResidualLoopBuilder.CreateStore(Element, StoreGEP, DstIsVolatile);
+
+      // After the residual loop, go to an intermediate block.
+      BasicBlock *IntermediateBB = BasicBlock::Create(
+          F->getContext(), "memmove_bwd_middle", F, MainLoopBB);
+      // Later code expects a terminator in the PredBB.
+      IRBuilder<> IntermediateBuilder(IntermediateBB);
+      IntermediateBuilder.CreateUnreachable();
+      ResidualLoopBuilder.CreateCondBr(
+          ResidualLoopBuilder.CreateICmpEQ(ResidualIndex,
+                                           RuntimeBytesCopiedMainLoop),
+          IntermediateBB, ResidualLoopBB);
+
+      ResidualLoopPhi->addIncoming(ResidualIndex, ResidualLoopBB);
+      ResidualLoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
+
+      // How to get to the residual:
+      BranchInst::Create(IntermediateBB, ResidualLoopBB, SkipResidualCondition,
+                         ThenTerm->getIterator());
+      ThenTerm->eraseFromParent();
+
+      PredBB = IntermediateBB;
+    }
 
-  PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
-  Value *IndexPtr = LoopBuilder.CreateSub(
-      LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
-  Value *Element = LoopBuilder.CreateAlignedLoad(
-      EltTy, LoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, IndexPtr),
-      PartSrcAlign, SrcIsVolatile, "element");
-  LoopBuilder.CreateAlignedStore(
-      Element, LoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, IndexPtr),
-      PartDstAlign, DstIsVolatile);
-  LoopBuilder.CreateCondBr(
-      LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
-      ExitBB, LoopBB);
-  LoopPhi->addIncoming(IndexPtr, LoopBB);
-  LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
-  BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm->getIterator());
-  ThenTerm->eraseFromParent();
+    // main loop
+    IRBuilder<> MainLoopBuilder(MainLoopBB);
+    PHINode *MainLoopPhi = MainLoopBuilder.CreatePHI(ILengthType, 0);
+    Value *MainIndex =
+        MainLoopBuilder.CreateSub(MainLoopPhi, One, "bwd_main_index");
+    Value *LoadGEP =
+        MainLoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, MainIndex);
+    Value *Element = MainLoopBuilder.CreateAlignedLoad(
+        LoopOpType, LoadGEP, PartSrcAlign, SrcIsVolatile, "element");
+    Value *StoreGEP =
+        MainLoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, MainIndex);
+    MainLoopBuilder.CreateAlignedStore(Element, StoreGEP, PartDstAlign,
+                                       DstIsVolatile);
+    MainLoopBuilder.CreateCondBr(MainLoopBuilder.CreateICmpEQ(MainIndex, Zero),
+                                 ExitBB, MainLoopBB);
+    MainLoopPhi->addIncoming(MainIndex, MainLoopBB);
+    MainLoopPhi->addIncoming(RuntimeLoopCount, PredBB);
+
+    // How to get to the main loop:
+    Instruction *PredBBTerm = PredBB->getTerminator();
+    BranchInst::Create(ExitBB, MainLoopBB, SkipMainCondition,
+                       PredBBTerm->getIterator());
+    PredBBTerm->eraseFromParent();
+  }
 
   // Copying forward.
-  BasicBlock *FwdLoopBB =
-    BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB);
-  IRBuilder<> FwdLoopBuilder(FwdLoopBB);
-  PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
-  Value *SrcGEP = FwdLoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, FwdCopyPhi);
-  Value *FwdElement = FwdLoopBuilder.CreateAlignedLoad(
-      EltTy, SrcGEP, PartSrcAlign, SrcIsVolatile, "element");
-  Value *DstGEP = FwdLoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, FwdCopyPhi);
-  FwdLoopBuilder.CreateAlignedStore(FwdElement, DstGEP, PartDstAlign,
-                                    DstIsVolatile);
-  Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
-      FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
-  FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
-                              ExitBB, FwdLoopBB);
-  FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
-  FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
-
-  BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm->getIterator());
-  ElseTerm->eraseFromParent();
+  // main loop
+  {
+    BasicBlock *MainLoopBB =
+        BasicBlock::Create(F->getContext(), "memmove_fwd_main_loop", F, ExitBB);
+    IRBuilder<> MainLoopBuilder(MainLoopBB);
+    PHINode *MainLoopPhi =
+        MainLoopBuilder.CreatePHI(ILengthType, 0, "fwd_main_index");
+    Value *LoadGEP =
+        MainLoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, MainLoopPhi);
+    Value *Element = MainLoopBuilder.CreateAlignedLoad(
+        LoopOpType, LoadGEP, PartSrcAlign, SrcIsVolatile, "element");
+    Value *StoreGEP =
+        MainLoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, MainLoopPhi);
+    MainLoopBuilder.CreateAlignedStore(Element, StoreGEP, PartDstAlign,
+                                       DstIsVolatile);
+    Value *MainIndex = MainLoopBuilder.CreateAdd(MainLoopPhi, One);
+    MainLoopPhi->addIncoming(MainIndex, MainLoopBB);
+    MainLoopPhi->addIncoming(Zero, CopyForwardBB);
+
+    Instruction *CopyFwdBBTerm = CopyForwardBB->getTerminator();
+    BasicBlock *SuccessorBB = ExitBB;
+    if (RequiresResidual)
+      SuccessorBB =
+          BasicBlock::Create(F->getContext(), "memmove_fwd_middle", F, ExitBB);
+
+    // leaving or staying in the main loop
+    MainLoopBuilder.CreateCondBr(
+        MainLoopBuilder.CreateICmpEQ(MainIndex, RuntimeLoopCount), SuccessorBB,
+        MainLoopBB);
+
+    // getting in or skipping the main loop
+    BranchInst::Create(SuccessorBB, MainLoopBB, SkipMainCondition,
+                       CopyFwdBBTerm->getIterator());
+    CopyFwdBBTerm->eraseFromParent();
+
+    if (RequiresResidual) {
+      BasicBlock *IntermediateBB = SuccessorBB;
+      IRBuilder<> IntermediateBuilder(IntermediateBB);
+      BasicBlock *ResidualLoopBB = BasicBlock::Create(
+          F->getContext(), "memmove_fwd_residual_loop", F, ExitBB);
+      IntermediateBuilder.CreateCondBr(SkipResidualCondition, ExitBB,
+                                       ResidualLoopBB);
+
+      // Residual loop
+      IRBuilder<> ResidualLoopBuilder(ResidualLoopBB);
+      PHINode *ResidualLoopPhi =
+          ResidualLoopBuilder.CreatePHI(ILengthType, 0, "fwd_residual_index");
+      Value *LoadGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr,
+                                                             ResidualLoopPhi);
+      Value *Element = ResidualLoopBuilder.CreateLoad(Int8Type, LoadGEP,
+                                                      SrcIsVolatile, "element");
+      Value *StoreGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr,
+                                                              ResidualLoopPhi);
+      ResidualLoopBuilder.CreateStore(Element, StoreGEP, DstIsVolatile);
----------------
ritter-x2a wrote:

Implemented in b491cb6.

https://github.com/llvm/llvm-project/pull/100122


More information about the llvm-commits mailing list