[llvm] da41214 - Add support for atomic memory copy lowering

Evgeniy Brevnov via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 7 20:41:38 PDT 2022


Author: Evgeniy Brevnov
Date: 2022-04-08T10:41:31+07:00
New Revision: da41214d653808db86bcefeb97da842012ebd104

URL: https://github.com/llvm/llvm-project/commit/da41214d653808db86bcefeb97da842012ebd104
DIFF: https://github.com/llvm/llvm-project/commit/da41214d653808db86bcefeb97da842012ebd104.diff

LOG: Add support for atomic memory copy lowering

Currently, the utility supports lowering of non-atomic memory transfer routines only. This patch adds support for the atomic version of memcpy. This may be useful for targets that do not support atomic memcpy.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D118443

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
    llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
    llvm/unittests/Transforms/Utils/MemTransferLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f9e44fc527cda..acda8a0e60229 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1290,9 +1290,14 @@ class TargetTransformInfo {
                                            Type *ExpectedType) const;
 
   /// \returns The type to use in a loop expansion of a memcpy call.
-  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign, unsigned DestAlign) const;
+  /// \param AtomicElementSize If set, the copy is an element-wise atomic
+  /// memcpy with elements of this many bytes; the returned type must then be
+  /// a non-vector type whose store size is a multiple of the element size.
+  Type *
+  getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
+                            unsigned SrcAddrSpace, unsigned DestAddrSpace,
+                            unsigned SrcAlign, unsigned DestAlign,
+                            Optional<uint32_t> AtomicElementSize = None) const;
 
   /// \param[out] OpsOut The operand types to copy RemainingBytes of memory.
   /// \param RemainingBytes The number of bytes to copy.
@@ -1303,7 +1308,8 @@ class TargetTransformInfo {
   void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const;
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize = None) const;
 
   /// \returns True if the two functions have compatible attributes for inlining
   /// purposes.
@@ -1745,15 +1751,17 @@ class TargetTransformInfo::Concept {
   virtual unsigned getAtomicMemIntrinsicMaxElementSize() const = 0;
   virtual Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
                                                    Type *ExpectedType) = 0;
-  virtual Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                          unsigned SrcAddrSpace,
-                                          unsigned DestAddrSpace,
-                                          unsigned SrcAlign,
-                                          unsigned DestAlign) const = 0;
+  virtual Type *
+  getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
+                            unsigned SrcAddrSpace, unsigned DestAddrSpace,
+                            unsigned SrcAlign, unsigned DestAlign,
+                            Optional<uint32_t> AtomicElementSize) const = 0;
+
   virtual void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const = 0;
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const = 0;
   virtual bool areInlineCompatible(const Function *Caller,
                                    const Function *Callee) const = 0;
   virtual bool areTypesABICompatible(const Function *Caller,
@@ -2315,20 +2323,22 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                            Type *ExpectedType) override {
     return Impl.getOrCreateResultFromMemIntrinsic(Inst, ExpectedType);
   }
-  Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                  unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign,
-                                  unsigned DestAlign) const override {
+  Type *getMemcpyLoopLoweringType(
+      LLVMContext &Context, Value *Length, unsigned SrcAddrSpace,
+      unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicElementSize) const override {
     return Impl.getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace,
-                                          DestAddrSpace, SrcAlign, DestAlign);
+                                          DestAddrSpace, SrcAlign, DestAlign,
+                                          AtomicElementSize);
   }
   void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const override {
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const override {
     Impl.getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
                                            SrcAddrSpace, DestAddrSpace,
-                                           SrcAlign, DestAlign);
+                                           SrcAlign, DestAlign, AtomicCpySize);
   }
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const override {

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d0d4f43e98db..42d38ac570fd9 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -703,16 +703,22 @@ class TargetTransformInfoImplBase {
 
   Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                   unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign, unsigned DestAlign) const {
-    return Type::getInt8Ty(Context);
+                                  unsigned SrcAlign, unsigned DestAlign,
+                                  Optional<uint32_t> AtomicElementSize) const {
+    return AtomicElementSize ? Type::getIntNTy(Context, *AtomicElementSize * 8)
+                             : Type::getInt8Ty(Context);
   }
 
   void getMemcpyLoopResidualLoweringType(
       SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
       unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-      unsigned SrcAlign, unsigned DestAlign) const {
-    for (unsigned i = 0; i != RemainingBytes; ++i)
-      OpsOut.push_back(Type::getInt8Ty(Context));
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const {
+    unsigned OpSizeInBytes = AtomicCpySize ? *AtomicCpySize : 1;
+    assert(RemainingBytes % OpSizeInBytes == 0 &&
+           "Residual size must be a multiple of the atomic element size");
+    Type *OpType = Type::getIntNTy(Context, OpSizeInBytes * 8);
+    OpsOut.append(RemainingBytes / OpSizeInBytes, OpType);
   }
 
   bool areInlineCompatible(const Function *Caller,

diff  --git a/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h b/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
index a46b7d4d17847..acf59ff580a4e 100644
--- a/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
+++ b/llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h
@@ -14,8 +14,11 @@
 #ifndef LLVM_TRANSFORMS_UTILS_LOWERMEMINTRINSICS_H
 #define LLVM_TRANSFORMS_UTILS_LOWERMEMINTRINSICS_H
 
+#include "llvm/ADT/Optional.h"
+
 namespace llvm {
 
+class AtomicMemCpyInst;
 class ConstantInt;
 class Instruction;
 class MemCpyInst;
@@ -32,7 +35,8 @@ void createMemCpyLoopUnknownSize(Instruction *InsertBefore, Value *SrcAddr,
                                  Value *DstAddr, Value *CopyLen, Align SrcAlign,
                                  Align DestAlign, bool SrcIsVolatile,
                                  bool DstIsVolatile, bool CanOverlap,
-                                 const TargetTransformInfo &TTI);
+                                 const TargetTransformInfo &TTI,
+                                 Optional<uint32_t> AtomicElementSize = None);
 
 /// Emit a loop implementing the semantics of an llvm.memcpy whose size is a
 /// compile time constant. Loop is inserted at \p InsertBefore.
@@ -40,7 +44,8 @@ void createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
                                Value *DstAddr, ConstantInt *CopyLen,
                                Align SrcAlign, Align DestAlign,
                                bool SrcIsVolatile, bool DstIsVolatile,
-                               bool CanOverlap, const TargetTransformInfo &TTI);
+                               bool CanOverlap, const TargetTransformInfo &TTI,
+                               Optional<uint32_t> AtomicElementSize = None);
 
 /// Expand \p MemCpy as a loop. \p MemCpy is not deleted.
 void expandMemCpyAsLoop(MemCpyInst *MemCpy, const TargetTransformInfo &TTI,
@@ -52,6 +57,11 @@ void expandMemMoveAsLoop(MemMoveInst *MemMove);
 /// Expand \p MemSet as a loop. \p MemSet is not deleted.
 void expandMemSetAsLoop(MemSetInst *MemSet);
 
+/// Expand \p AtomicMemCpy as a loop. \p AtomicMemCpy is not deleted.
+void expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemCpy,
+                              const TargetTransformInfo &TTI,
+                              ScalarEvolution *SE);
+
 } // End llvm namespace
 
 #endif

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7366c7008c30f..eded89b16c223 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -976,18 +976,21 @@ Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic(
 
 Type *TargetTransformInfo::getMemcpyLoopLoweringType(
     LLVMContext &Context, Value *Length, unsigned SrcAddrSpace,
-    unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign) const {
+    unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicElementSize) const {
   return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAddrSpace,
-                                            DestAddrSpace, SrcAlign, DestAlign);
+                                            DestAddrSpace, SrcAlign, DestAlign,
+                                            AtomicElementSize);
 }
 
 void TargetTransformInfo::getMemcpyLoopResidualLoweringType(
     SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
     unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-    unsigned SrcAlign, unsigned DestAlign) const {
-  TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes,
-                                             SrcAddrSpace, DestAddrSpace,
-                                             SrcAlign, DestAlign);
+    unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicCpySize) const {
+  TTIImpl->getMemcpyLoopResidualLoweringType(
+      OpsOut, Context, RemainingBytes, SrcAddrSpace, DestAddrSpace, SrcAlign,
+      DestAlign, AtomicCpySize);
 }
 
 bool TargetTransformInfo::areInlineCompatible(const Function *Caller,

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index bdd22a4614f4f..0afebe0f5a12f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -410,11 +410,14 @@ bool GCNTTIImpl::isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes,
 // unaligned access is legal?
 //
 // FIXME: This could use fine tuning and microbenchmarks.
-Type *GCNTTIImpl::getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
-                                            unsigned SrcAddrSpace,
-                                            unsigned DestAddrSpace,
-                                            unsigned SrcAlign,
-                                            unsigned DestAlign) const {
+Type *GCNTTIImpl::getMemcpyLoopLoweringType(
+    LLVMContext &Context, Value *Length, unsigned SrcAddrSpace,
+    unsigned DestAddrSpace, unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicElementSize) const {
+
+  if (AtomicElementSize)
+    return Type::getIntNTy(Context, *AtomicElementSize * 8);
+
   unsigned MinAlign = std::min(SrcAlign, DestAlign);
 
   // A (multi-)dword access at an address == 2 (mod 4) will be decomposed by the
@@ -439,11 +442,21 @@ Type *GCNTTIImpl::getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
 }
 
 void GCNTTIImpl::getMemcpyLoopResidualLoweringType(
-  SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
-  unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
-  unsigned SrcAlign, unsigned DestAlign) const {
+    SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
+    unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
+    unsigned SrcAlign, unsigned DestAlign,
+    Optional<uint32_t> AtomicCpySize) const {
   assert(RemainingBytes < 16);
 
+  // Atomic copies use only the element-sized ops of the generic lowering; do
+  // not also append the alignment-based operand types chosen below.
+  if (AtomicCpySize) {
+    BaseT::getMemcpyLoopResidualLoweringType(
+        OpsOut, Context, RemainingBytes, SrcAddrSpace, DestAddrSpace, SrcAlign,
+        DestAlign, AtomicCpySize);
+    return;
+  }
+
   unsigned MinAlign = std::min(SrcAlign, DestAlign);
 
   if (MinAlign != 2) {

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index 4743042f5faea..ebeb05e885124 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -135,15 +135,14 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
                                     unsigned AddrSpace) const;
   Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length,
                                   unsigned SrcAddrSpace, unsigned DestAddrSpace,
-                                  unsigned SrcAlign, unsigned DestAlign) const;
-
-  void getMemcpyLoopResidualLoweringType(SmallVectorImpl<Type *> &OpsOut,
-                                         LLVMContext &Context,
-                                         unsigned RemainingBytes,
-                                         unsigned SrcAddrSpace,
-                                         unsigned DestAddrSpace,
-                                         unsigned SrcAlign,
-                                         unsigned DestAlign) const;
+                                  unsigned SrcAlign, unsigned DestAlign,
+                                  Optional<uint32_t> AtomicElementSize) const;
+
+  void getMemcpyLoopResidualLoweringType(
+      SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context,
+      unsigned RemainingBytes, unsigned SrcAddrSpace, unsigned DestAddrSpace,
+      unsigned SrcAlign, unsigned DestAlign,
+      Optional<uint32_t> AtomicCpySize) const;
   unsigned getMaxInterleaveFactor(unsigned VF);
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info) const;

diff  --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
index 3848e349ece3f..b4acb1b2ae90d 100644
--- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp
@@ -21,7 +21,8 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
                                      Align SrcAlign, Align DstAlign,
                                      bool SrcIsVolatile, bool DstIsVolatile,
                                      bool CanOverlap,
-                                     const TargetTransformInfo &TTI) {
+                                     const TargetTransformInfo &TTI,
+                                     Optional<uint32_t> AtomicElementSize) {
   // No need to expand zero length copies.
   if (CopyLen->isZero())
     return;
@@ -41,9 +42,15 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
 
   Type *TypeOfCopyLen = CopyLen->getType();
   Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
-      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value());
+      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(),
+      AtomicElementSize);
+  assert((!AtomicElementSize || !LoopOpType->isVectorTy()) &&
+         "Atomic memcpy lowering is not supported for vector operand type");
 
   unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
+  assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) &&
+         "Atomic memcpy lowering is not supported for selected operand size");
+
   uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize;
 
   if (LoopEndCount != 0) {
@@ -90,6 +97,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
       // Indicate that stores don't overlap loads.
       Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
     }
+    if (AtomicElementSize) {
+      Load->setAtomic(AtomicOrdering::Unordered);
+      Store->setAtomic(AtomicOrdering::Unordered);
+    }
     Value *NewIndex =
         LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1U));
     LoopIndex->addIncoming(NewIndex, LoopBB);
@@ -109,7 +120,7 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
     SmallVector<Type *, 5> RemainingOps;
     TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes,
                                           SrcAS, DstAS, SrcAlign.value(),
-                                          DstAlign.value());
+                                          DstAlign.value(), AtomicElementSize);
 
     for (auto OpTy : RemainingOps) {
       Align PartSrcAlign(commonAlignment(SrcAlign, BytesCopied));
@@ -117,6 +128,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
 
       // Calaculate the new index
       unsigned OperandSize = DL.getTypeStoreSize(OpTy);
+      assert(
+          (!AtomicElementSize || OperandSize % *AtomicElementSize == 0) &&
+          "Atomic memcpy lowering is not supported for selected operand size");
+
       uint64_t GepIndex = BytesCopied / OperandSize;
       assert(GepIndex * OperandSize == BytesCopied &&
              "Division should have no Remainder!");
@@ -147,6 +162,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr,
         // Indicate that stores don't overlap loads.
         Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
       }
+      if (AtomicElementSize) {
+        Load->setAtomic(AtomicOrdering::Unordered);
+        Store->setAtomic(AtomicOrdering::Unordered);
+      }
       BytesCopied += OperandSize;
     }
   }
@@ -159,7 +178,8 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
                                        Value *CopyLen, Align SrcAlign,
                                        Align DstAlign, bool SrcIsVolatile,
                                        bool DstIsVolatile, bool CanOverlap,
-                                       const TargetTransformInfo &TTI) {
+                                       const TargetTransformInfo &TTI,
+                                       Optional<uint32_t> AtomicElementSize) {
   BasicBlock *PreLoopBB = InsertBefore->getParent();
   BasicBlock *PostLoopBB =
       PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion");
@@ -176,8 +196,13 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
   unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
 
   Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
-      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value());
+      Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(),
+      AtomicElementSize);
+  assert((!AtomicElementSize || !LoopOpType->isVectorTy()) &&
+         "Atomic memcpy lowering is not supported for vector operand type");
   unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
+  assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) &&
+         "Atomic memcpy lowering is not supported for selected operand size");
 
   IRBuilder<> PLBuilder(PreLoopBB->getTerminator());
 
@@ -225,14 +250,27 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
     // Indicate that stores don't overlap loads.
     Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
   }
+  if (AtomicElementSize) {
+    Load->setAtomic(AtomicOrdering::Unordered);
+    Store->setAtomic(AtomicOrdering::Unordered);
+  }
   Value *NewIndex =
       LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLenType, 1U));
   LoopIndex->addIncoming(NewIndex, LoopBB);
 
-  if (!LoopOpIsInt8) {
-   // Add in the
-   Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize);
-   Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual);
+  bool RequiresResidual =
+      !LoopOpIsInt8 && !(AtomicElementSize && LoopOpSize == AtomicElementSize);
+  if (RequiresResidual) {
+    Type *ResLoopOpType = AtomicElementSize
+                              ? Type::getIntNTy(Ctx, *AtomicElementSize * 8)
+                              : Int8Type;
+    unsigned ResLoopOpSize = DL.getTypeStoreSize(ResLoopOpType);
+    assert(ResLoopOpSize == (AtomicElementSize ? *AtomicElementSize : 1) &&
+           "Store size is expected to match type size");
+
+    // Bytes left for the residual loop and bytes copied by the main loop.
+    Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize);
+    Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual);
 
     // Loop body for the residual copy.
     BasicBlock *ResLoopBB = BasicBlock::Create(Ctx, "loop-memcpy-residual",
@@ -267,30 +305,38 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore,
         ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index");
     ResidualIndex->addIncoming(Zero, ResHeaderBB);
 
     Value *SrcAsInt8 =
         ResBuilder.CreateBitCast(SrcAddr, PointerType::get(Int8Type, SrcAS));
     Value *DstAsInt8 =
         ResBuilder.CreateBitCast(DstAddr, PointerType::get(Int8Type, DstAS));
     Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex);
-    Value *SrcGEP =
-        ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset);
-    LoadInst *Load = ResBuilder.CreateAlignedLoad(Int8Type, SrcGEP,
+    // FullOffset is a byte offset, so index through i8 and only then cast to
+    // the residual operand type; indexing ResLoopOpType would scale it.
+    Value *SrcGEP = ResBuilder.CreateBitCast(
+        ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset),
+        PointerType::get(ResLoopOpType, SrcAS));
+    LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP,
                                                   PartSrcAlign, SrcIsVolatile);
     if (!CanOverlap) {
       // Set alias scope for loads.
       Load->setMetadata(LLVMContext::MD_alias_scope,
                         MDNode::get(Ctx, NewScope));
     }
-    Value *DstGEP =
-        ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset);
+    Value *DstGEP = ResBuilder.CreateBitCast(
+        ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset),
+        PointerType::get(ResLoopOpType, DstAS));
     StoreInst *Store = ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign,
                                                      DstIsVolatile);
     if (!CanOverlap) {
       // Indicate that stores don't overlap loads.
       Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope));
     }
-    Value *ResNewIndex =
-        ResBuilder.CreateAdd(ResidualIndex, ConstantInt::get(CopyLenType, 1U));
+    if (AtomicElementSize) {
+      Load->setAtomic(AtomicOrdering::Unordered);
+      Store->setAtomic(AtomicOrdering::Unordered);
+    }
+    Value *ResNewIndex = ResBuilder.CreateAdd(
+        ResidualIndex, ConstantInt::get(CopyLenType, ResLoopOpSize));
     ResidualIndex->addIncoming(ResNewIndex, ResLoopBB);
 
     // Create the loop branch condition.
@@ -471,17 +517,21 @@ static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr,
                            NewBB);
 }
 
-void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy,
-                              const TargetTransformInfo &TTI,
-                              ScalarEvolution *SE) {
-  bool CanOverlap = true;
+template <typename T>
+static bool canOverlap(MemTransferBase<T> *Memcpy, ScalarEvolution *SE) {
   if (SE) {
     auto *SrcSCEV = SE->getSCEV(Memcpy->getRawSource());
     auto *DestSCEV = SE->getSCEV(Memcpy->getRawDest());
     if (SE->isKnownPredicateAt(CmpInst::ICMP_NE, SrcSCEV, DestSCEV, Memcpy))
-      CanOverlap = false;
+      return false;
   }
+  return true;
+}
 
+void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy,
+                              const TargetTransformInfo &TTI,
+                              ScalarEvolution *SE) {
+  bool CanOverlap = canOverlap(Memcpy, SE);
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) {
     createMemCpyLoopKnownSize(
         /* InsertBefore */ Memcpy,
@@ -528,3 +578,35 @@ void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
                    /* Alignment */ Memset->getDestAlign().valueOrOne(),
                    Memset->isVolatile());
 }
+
+void llvm::expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemcpy,
+                                    const TargetTransformInfo &TTI,
+                                    ScalarEvolution *SE) {
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(AtomicMemcpy->getLength())) {
+    createMemCpyLoopKnownSize(
+        /* InsertBefore */ AtomicMemcpy,
+        /* SrcAddr */ AtomicMemcpy->getRawSource(),
+        /* DstAddr */ AtomicMemcpy->getRawDest(),
+        /* CopyLen */ CI,
+        /* SrcAlign */ AtomicMemcpy->getSourceAlign().valueOrOne(),
+        /* DestAlign */ AtomicMemcpy->getDestAlign().valueOrOne(),
+        /* SrcIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* DstIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* CanOverlap */ false, // SrcAddr & DstAddr may not overlap by spec.
+        /* TargetTransformInfo */ TTI,
+        /* AtomicCpySize */ AtomicMemcpy->getElementSizeInBytes());
+  } else {
+    createMemCpyLoopUnknownSize(
+        /* InsertBefore */ AtomicMemcpy,
+        /* SrcAddr */ AtomicMemcpy->getRawSource(),
+        /* DstAddr */ AtomicMemcpy->getRawDest(),
+        /* CopyLen */ AtomicMemcpy->getLength(),
+        /* SrcAlign */ AtomicMemcpy->getSourceAlign().valueOrOne(),
+        /* DestAlign */ AtomicMemcpy->getDestAlign().valueOrOne(),
+        /* SrcIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* DstIsVolatile */ AtomicMemcpy->isVolatile(),
+        /* CanOverlap */ false, // SrcAddr & DstAddr may not overlap by spec.
+        /* TargetTransformInfo */ TTI,
+        /* AtomicCpySize */ AtomicMemcpy->getElementSizeInBytes());
+  }
+}

diff  --git a/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp b/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
index df86e16bf641f..62afd91a704ee 100644
--- a/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
+++ b/llvm/unittests/Transforms/Utils/MemTransferLowering.cpp
@@ -174,4 +174,94 @@ TEST_F(MemTransferLowerTest, VecMemCpyKnownLength) {
 
   MPM.run(*M, MAM);
 }
+
+TEST_F(MemTransferLowerTest, AtomicMemCpyKnownLength) {
+  ParseAssembly("declare void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32*, "
+                "i32 *, i64, i32)\n"
+                "define void @foo(i32* %dst, i32* %src, i64 %n) optsize {\n"
+                "entry:\n"
+                "  %is_not_equal = icmp ne i32* %dst, %src\n"
+                "  br i1 %is_not_equal, label %memcpy, label %exit\n"
+                "memcpy:\n"
+                "  call void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32* "
+                "%dst, i32* %src, "
+                "i64 1024, i32 4)\n"
+                "  br label %exit\n"
+                "exit:\n"
+                "  ret void\n"
+                "}\n");
+
+  FunctionPassManager FPM;
+  FPM.addPass(ForwardingPass(
+      [=](Function &F, FunctionAnalysisManager &FAM) -> PreservedAnalyses {
+        TargetTransformInfo TTI(M->getDataLayout());
+        auto *MemCpyBB = getBasicBlockByName(F, "memcpy");
+        Instruction *Inst = &MemCpyBB->front();
+        assert(isa<AtomicMemCpyInst>(Inst) &&
+               "Expecting llvm.memcpy.element.unordered.atomic instruction");
+        AtomicMemCpyInst *MemCpyI = cast<AtomicMemCpyInst>(Inst);
+        auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
+        expandAtomicMemCpyAsLoop(MemCpyI, TTI, &SE);
+        auto *CopyLoopBB = getBasicBlockByName(F, "load-store-loop");
+        Instruction *LoadInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Load, 1);
+        EXPECT_TRUE(LoadInst->isAtomic());
+        EXPECT_NE(LoadInst->getMetadata(LLVMContext::MD_alias_scope), nullptr);
+        Instruction *StoreInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Store, 1);
+        EXPECT_TRUE(StoreInst->isAtomic());
+        EXPECT_NE(StoreInst->getMetadata(LLVMContext::MD_noalias), nullptr);
+        return PreservedAnalyses::none();
+      }));
+  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+
+  MPM.run(*M, MAM);
+}
+
+TEST_F(MemTransferLowerTest, AtomicMemCpyUnKnownLength) {
+  ParseAssembly("declare void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32*, "
+                "i32 *, i64, i32)\n"
+                "define void @foo(i32* %dst, i32* %src, i64 %n) optsize {\n"
+                "entry:\n"
+                "  %is_not_equal = icmp ne i32* %dst, %src\n"
+                "  br i1 %is_not_equal, label %memcpy, label %exit\n"
+                "memcpy:\n"
+                "  call void "
+                "@llvm.memcpy.element.unordered.atomic.p0i32.p0i32.i64(i32* "
+                "%dst, i32* %src, "
+                "i64 %n, i32 4)\n"
+                "  br label %exit\n"
+                "exit:\n"
+                "  ret void\n"
+                "}\n");
+
+  FunctionPassManager FPM;
+  FPM.addPass(ForwardingPass(
+      [=](Function &F, FunctionAnalysisManager &FAM) -> PreservedAnalyses {
+        TargetTransformInfo TTI(M->getDataLayout());
+        auto *MemCpyBB = getBasicBlockByName(F, "memcpy");
+        Instruction *Inst = &MemCpyBB->front();
+        assert(isa<AtomicMemCpyInst>(Inst) &&
+               "Expecting llvm.memcpy.element.unordered.atomic instruction");
+        AtomicMemCpyInst *MemCpyI = cast<AtomicMemCpyInst>(Inst);
+        auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F);
+        expandAtomicMemCpyAsLoop(MemCpyI, TTI, &SE);
+        auto *CopyLoopBB = getBasicBlockByName(F, "loop-memcpy-expansion");
+        Instruction *LoadInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Load, 1);
+        EXPECT_TRUE(LoadInst->isAtomic());
+        EXPECT_NE(LoadInst->getMetadata(LLVMContext::MD_alias_scope), nullptr);
+        Instruction *StoreInst =
+            getInstructionByOpcode(*CopyLoopBB, Instruction::Store, 1);
+        EXPECT_TRUE(StoreInst->isAtomic());
+        EXPECT_NE(StoreInst->getMetadata(LLVMContext::MD_noalias), nullptr);
+        return PreservedAnalyses::none();
+      }));
+  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+
+  MPM.run(*M, MAM);
+}
 } // namespace


        


More information about the llvm-commits mailing list