[llvm] [LowerMemIntrinsics] Lower llvm.memmove to wide memory accesses (PR #100122)
Fabian Ritter via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 24 01:05:09 PDT 2024
================
@@ -378,85 +382,229 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr,
BasicBlock *OrigBB = InsertBefore->getParent();
Function *F = OrigBB->getParent();
const DataLayout &DL = F->getDataLayout();
- // TODO: Use different element type if possible?
- Type *EltTy = Type::getInt8Ty(F->getContext());
+ LLVMContext &Ctx = OrigBB->getContext();
+ unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
+ unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+
+ Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
+ Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value());
+ unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
+ Type *Int8Type = Type::getInt8Ty(Ctx);
+ bool LoopOpIsInt8 = LoopOpType == Int8Type;
+
+ // If the memory accesses are wider than one byte, residual loops with
+ // i8-accesses are required to move remaining bytes.
+ bool RequiresResidual = !LoopOpIsInt8;
+
+ // Calculate the loop trip count and remaining bytes to copy after the loop.
+ IntegerType *ILengthType = dyn_cast<IntegerType>(TypeOfCopyLen);
+ assert(ILengthType &&
+ "expected size argument to memcpy to be an integer type!");
+ ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize);
+ ConstantInt *Zero = ConstantInt::get(ILengthType, 0);
+ ConstantInt *One = ConstantInt::get(ILengthType, 1);
+
+ IRBuilder<> PLBuilder(InsertBefore);
+
+ Value *RuntimeLoopCount = CopyLen;
+ Value *RuntimeLoopRemainder = nullptr;
+ Value *RuntimeBytesCopiedMainLoop = CopyLen;
+ Value *SkipResidualCondition = nullptr;
+ if (RequiresResidual) {
+ RuntimeLoopCount =
+ getRuntimeLoopCount(DL, PLBuilder, CopyLen, CILoopOpSize, LoopOpSize);
+ RuntimeLoopRemainder = getRuntimeLoopRemainder(DL, PLBuilder, CopyLen,
+ CILoopOpSize, LoopOpSize);
+ RuntimeBytesCopiedMainLoop =
+ PLBuilder.CreateSub(CopyLen, RuntimeLoopRemainder);
+ SkipResidualCondition =
+ PLBuilder.CreateICmpEQ(RuntimeLoopRemainder, Zero, "skip_residual");
+ }
+ Value *SkipMainCondition =
+ PLBuilder.CreateICmpEQ(RuntimeLoopCount, Zero, "skip_main");
// Create a comparison of src and dst, based on which we jump to either
// the forward-copy part of the function (if src >= dst) or the backwards-copy
// part (if src < dst).
// SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
// structure. Its block terminators (unconditional branches) are replaced by
// the appropriate conditional branches when the loop is built.
- ICmpInst *PtrCompare = new ICmpInst(InsertBefore->getIterator(), ICmpInst::ICMP_ULT,
- SrcAddr, DstAddr, "compare_src_dst");
+ Value *PtrCompare =
+ PLBuilder.CreateICmpULT(SrcAddr, DstAddr, "compare_src_dst");
Instruction *ThenTerm, *ElseTerm;
- SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore->getIterator(), &ThenTerm,
- &ElseTerm);
-
- // Each part of the function consists of two blocks:
- // copy_backwards: used to skip the loop when n == 0
- // copy_backwards_loop: the actual backwards loop BB
- // copy_forward: used to skip the loop when n == 0
- // copy_forward_loop: the actual forward loop BB
+ SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore->getIterator(),
+ &ThenTerm, &ElseTerm);
+
+ // If the LoopOpSize is greater than 1, each part of the function consists
+ // of four blocks:
+ // memmove_copy_backwards:
+ // skip the residual loop when 0 iterations are required
+ // memmove_bwd_residual_loop:
+ // copy the last few bytes individually so that the remaining length is
+ // a multiple of the LoopOpSize
+ // memmove_bwd_middle: skip the main loop when 0 iterations are required
+ // memmove_bwd_main_loop: the actual backwards loop BB with wide accesses
+ // memmove_copy_forward: skip the main loop when 0 iterations are required
+ // memmove_fwd_main_loop: the actual forward loop BB with wide accesses
+ // memmove_fwd_middle: skip the residual loop when 0 iterations are required
+ // memmove_fwd_residual_loop: copy the last few bytes individually
+ //
+ // The order of the main and residual loops is switched between copying
+ // forward and backward so that the residual loop always operates on the end
+ // of the moved range. This is based on the assumption that buffers whose
+ // start is aligned with the LoopOpSize are more common than buffers whose end
+ // is.
+ //
+ // If the LoopOpSize is 1, each part of the function consists of two blocks:
+ // memmove_copy_backwards: skip the loop when 0 iterations are required
+ // memmove_bwd_main_loop: the actual backwards loop BB
+ // memmove_copy_forward: skip the loop when 0 iterations are required
+ // memmove_fwd_main_loop: the actual forward loop BB
BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
- CopyBackwardsBB->setName("copy_backwards");
+ CopyBackwardsBB->setName("memmove_copy_backwards");
BasicBlock *CopyForwardBB = ElseTerm->getParent();
- CopyForwardBB->setName("copy_forward");
+ CopyForwardBB->setName("memmove_copy_forward");
BasicBlock *ExitBB = InsertBefore->getParent();
ExitBB->setName("memmove_done");
- unsigned PartSize = DL.getTypeStoreSize(EltTy);
- Align PartSrcAlign(commonAlignment(SrcAlign, PartSize));
- Align PartDstAlign(commonAlignment(DstAlign, PartSize));
-
- // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
- // between both backwards and forward copy clauses.
- ICmpInst *CompareN =
- new ICmpInst(OrigBB->getTerminator()->getIterator(), ICmpInst::ICMP_EQ, CopyLen,
- ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
+ Align PartSrcAlign(commonAlignment(SrcAlign, LoopOpSize));
+ Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));
// Copying backwards.
- BasicBlock *LoopBB =
- BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB);
- IRBuilder<> LoopBuilder(LoopBB);
+ {
+ BasicBlock *MainLoopBB = BasicBlock::Create(
+ F->getContext(), "memmove_bwd_main_loop", F, CopyForwardBB);
+
+ // The predecessor of the memmove_bwd_main_loop. Updated in the
+ // following if a residual loop is emitted first.
+ BasicBlock *PredBB = CopyBackwardsBB;
+
+ if (RequiresResidual) {
+ // backwards residual loop
+ BasicBlock *ResidualLoopBB = BasicBlock::Create(
+ F->getContext(), "memmove_bwd_residual_loop", F, MainLoopBB);
+ IRBuilder<> ResidualLoopBuilder(ResidualLoopBB);
+ PHINode *ResidualLoopPhi = ResidualLoopBuilder.CreatePHI(ILengthType, 0);
+ Value *ResidualIndex = ResidualLoopBuilder.CreateSub(
+ ResidualLoopPhi, One, "bwd_residual_index");
+ Value *LoadGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr,
+ ResidualIndex);
+ Value *Element = ResidualLoopBuilder.CreateLoad(Int8Type, LoadGEP,
+ SrcIsVolatile, "element");
+ Value *StoreGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr,
+ ResidualIndex);
+ ResidualLoopBuilder.CreateStore(Element, StoreGEP, DstIsVolatile);
+
+ // After the residual loop, go to an intermediate block.
+ BasicBlock *IntermediateBB = BasicBlock::Create(
+ F->getContext(), "memmove_bwd_middle", F, MainLoopBB);
+ // Later code expects a terminator in the PredBB.
+ IRBuilder<> IntermediateBuilder(IntermediateBB);
+ IntermediateBuilder.CreateUnreachable();
+ ResidualLoopBuilder.CreateCondBr(
+ ResidualLoopBuilder.CreateICmpEQ(ResidualIndex,
+ RuntimeBytesCopiedMainLoop),
+ IntermediateBB, ResidualLoopBB);
+
+ ResidualLoopPhi->addIncoming(ResidualIndex, ResidualLoopBB);
+ ResidualLoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
+
+ // How to get to the residual:
+ BranchInst::Create(IntermediateBB, ResidualLoopBB, SkipResidualCondition,
+ ThenTerm->getIterator());
+ ThenTerm->eraseFromParent();
+
+ PredBB = IntermediateBB;
+ }
- PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
- Value *IndexPtr = LoopBuilder.CreateSub(
- LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
- Value *Element = LoopBuilder.CreateAlignedLoad(
- EltTy, LoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, IndexPtr),
- PartSrcAlign, SrcIsVolatile, "element");
- LoopBuilder.CreateAlignedStore(
- Element, LoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, IndexPtr),
- PartDstAlign, DstIsVolatile);
- LoopBuilder.CreateCondBr(
- LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
- ExitBB, LoopBB);
- LoopPhi->addIncoming(IndexPtr, LoopBB);
- LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
- BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm->getIterator());
- ThenTerm->eraseFromParent();
+ // main loop
+ IRBuilder<> MainLoopBuilder(MainLoopBB);
+ PHINode *MainLoopPhi = MainLoopBuilder.CreatePHI(ILengthType, 0);
+ Value *MainIndex =
+ MainLoopBuilder.CreateSub(MainLoopPhi, One, "bwd_main_index");
+ Value *LoadGEP =
+ MainLoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, MainIndex);
+ Value *Element = MainLoopBuilder.CreateAlignedLoad(
+ LoopOpType, LoadGEP, PartSrcAlign, SrcIsVolatile, "element");
+ Value *StoreGEP =
+ MainLoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, MainIndex);
+ MainLoopBuilder.CreateAlignedStore(Element, StoreGEP, PartDstAlign,
+ DstIsVolatile);
+ MainLoopBuilder.CreateCondBr(MainLoopBuilder.CreateICmpEQ(MainIndex, Zero),
+ ExitBB, MainLoopBB);
+ MainLoopPhi->addIncoming(MainIndex, MainLoopBB);
+ MainLoopPhi->addIncoming(RuntimeLoopCount, PredBB);
+
+ // How to get to the main loop:
+ Instruction *PredBBTerm = PredBB->getTerminator();
+ BranchInst::Create(ExitBB, MainLoopBB, SkipMainCondition,
+ PredBBTerm->getIterator());
+ PredBBTerm->eraseFromParent();
+ }
// Copying forward.
- BasicBlock *FwdLoopBB =
- BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB);
- IRBuilder<> FwdLoopBuilder(FwdLoopBB);
- PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
- Value *SrcGEP = FwdLoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, FwdCopyPhi);
- Value *FwdElement = FwdLoopBuilder.CreateAlignedLoad(
- EltTy, SrcGEP, PartSrcAlign, SrcIsVolatile, "element");
- Value *DstGEP = FwdLoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, FwdCopyPhi);
- FwdLoopBuilder.CreateAlignedStore(FwdElement, DstGEP, PartDstAlign,
- DstIsVolatile);
- Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
- FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
- FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
- ExitBB, FwdLoopBB);
- FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
- FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
-
- BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm->getIterator());
- ElseTerm->eraseFromParent();
+ // main loop
+ {
+ BasicBlock *MainLoopBB =
+ BasicBlock::Create(F->getContext(), "memmove_fwd_main_loop", F, ExitBB);
+ IRBuilder<> MainLoopBuilder(MainLoopBB);
+ PHINode *MainLoopPhi =
+ MainLoopBuilder.CreatePHI(ILengthType, 0, "fwd_main_index");
+ Value *LoadGEP =
+ MainLoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, MainLoopPhi);
+ Value *Element = MainLoopBuilder.CreateAlignedLoad(
+ LoopOpType, LoadGEP, PartSrcAlign, SrcIsVolatile, "element");
+ Value *StoreGEP =
+ MainLoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, MainLoopPhi);
+ MainLoopBuilder.CreateAlignedStore(Element, StoreGEP, PartDstAlign,
+ DstIsVolatile);
+ Value *MainIndex = MainLoopBuilder.CreateAdd(MainLoopPhi, One);
+ MainLoopPhi->addIncoming(MainIndex, MainLoopBB);
+ MainLoopPhi->addIncoming(Zero, CopyForwardBB);
+
+ Instruction *CopyFwdBBTerm = CopyForwardBB->getTerminator();
+ BasicBlock *SuccessorBB = ExitBB;
+ if (RequiresResidual)
+ SuccessorBB =
+ BasicBlock::Create(F->getContext(), "memmove_fwd_middle", F, ExitBB);
+
+ // leaving or staying in the main loop
+ MainLoopBuilder.CreateCondBr(
+ MainLoopBuilder.CreateICmpEQ(MainIndex, RuntimeLoopCount), SuccessorBB,
+ MainLoopBB);
+
+ // getting in or skipping the main loop
+ BranchInst::Create(SuccessorBB, MainLoopBB, SkipMainCondition,
+ CopyFwdBBTerm->getIterator());
+ CopyFwdBBTerm->eraseFromParent();
+
+ if (RequiresResidual) {
+ BasicBlock *IntermediateBB = SuccessorBB;
+ IRBuilder<> IntermediateBuilder(IntermediateBB);
+ BasicBlock *ResidualLoopBB = BasicBlock::Create(
+ F->getContext(), "memmove_fwd_residual_loop", F, ExitBB);
+ IntermediateBuilder.CreateCondBr(SkipResidualCondition, ExitBB,
+ ResidualLoopBB);
+
+ // Residual loop
+ IRBuilder<> ResidualLoopBuilder(ResidualLoopBB);
+ PHINode *ResidualLoopPhi =
+ ResidualLoopBuilder.CreatePHI(ILengthType, 0, "fwd_residual_index");
+ Value *LoadGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr,
+ ResidualLoopPhi);
+ Value *Element = ResidualLoopBuilder.CreateLoad(Int8Type, LoadGEP,
+ SrcIsVolatile, "element");
+ Value *StoreGEP = ResidualLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr,
+ ResidualLoopPhi);
+ ResidualLoopBuilder.CreateStore(Element, StoreGEP, DstIsVolatile);
----------------
ritter-x2a wrote:
Implemented in b491cb6.
https://github.com/llvm/llvm-project/pull/100122
More information about the llvm-commits
mailing list