[flang] [llvm] Reland "[flang][cuda] Add support for derived-type initialization on device #172568" (PR #172913)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 18 14:01:43 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Reland of #172568.

---

Patch is 84.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172913.diff


32 Files Affected:

- (modified) flang-rt/include/flang-rt/runtime/derived.h (+9-1) 
- (modified) flang-rt/include/flang-rt/runtime/work-queue.h (+17-7) 
- (modified) flang-rt/lib/cuda/allocatable.cpp (+6-6) 
- (modified) flang-rt/lib/cuda/memmove-function.cpp (+18) 
- (modified) flang-rt/lib/cuda/pointer.cpp (+6-6) 
- (modified) flang-rt/lib/runtime/allocatable.cpp (+3-3) 
- (modified) flang-rt/lib/runtime/derived.cpp (+24-6) 
- (modified) flang-rt/lib/runtime/pointer.cpp (+4-2) 
- (modified) flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h (+27) 
- (modified) flang/include/flang/Runtime/CUDA/allocatable.h (+2-2) 
- (modified) flang/include/flang/Runtime/CUDA/memmove-function.h (+6) 
- (modified) flang/include/flang/Runtime/CUDA/pointer.h (+2-2) 
- (modified) flang/include/flang/Runtime/allocatable.h (+9-1) 
- (modified) flang/include/flang/Runtime/freestanding-tools.h (+7-1) 
- (modified) flang/include/flang/Runtime/pointer.h (+8-1) 
- (modified) flang/lib/Lower/Allocatable.cpp (+1-9) 
- (modified) flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp (+4-3) 
- (modified) flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp (+6-1) 
- (modified) flang/test/Fir/CUDA/cuda-allocate.fir (+6-6) 
- (removed) flang/test/Lower/CUDA/TODO/cuda-allocate-default-init.cuf (-15) 
- (modified) flang/test/Lower/Intrinsics/c_loc.f90 (+1-1) 
- (modified) flang/test/Lower/OpenACC/acc-declare.f90 (+2-2) 
- (modified) flang/test/Lower/OpenMP/parallel-reduction-pointer-array.f90 (+1-1) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-pointer.f90 (+1-1) 
- (modified) flang/test/Lower/allocatable-polymorphic.f90 (+17-17) 
- (modified) flang/test/Lower/allocatable-runtime.f90 (+2-2) 
- (modified) flang/test/Lower/allocate-mold.f90 (+2-2) 
- (modified) flang/test/Lower/assign-statement.f90 (+1-1) 
- (modified) flang/test/Lower/nullify-polymorphic.f90 (+1-1) 
- (modified) flang/test/Lower/polymorphic.f90 (+1-1) 
- (modified) flang/test/Lower/volatile-allocatable.f90 (+9-9) 
- (modified) flang/test/Transforms/lower-repack-arrays.fir (+4-4) 


``````````diff
diff --git a/flang-rt/include/flang-rt/runtime/derived.h b/flang-rt/include/flang-rt/runtime/derived.h
index ac6962c57168c..20d022df56170 100644
--- a/flang-rt/include/flang-rt/runtime/derived.h
+++ b/flang-rt/include/flang-rt/runtime/derived.h
@@ -12,6 +12,7 @@
 #define FLANG_RT_RUNTIME_DERIVED_H_
 
 #include "flang/Common/api-attrs.h"
+#include "flang/Runtime/freestanding-tools.h"
 
 namespace Fortran::runtime::typeInfo {
 class DerivedType;
@@ -23,8 +24,15 @@ class Terminator;
 
 // Perform default component initialization, allocate automatic components.
 // Returns a STAT= code (0 when all's well).
+#ifdef RT_DEVICE_COMPILATION
 RT_API_ATTRS int Initialize(const Descriptor &, const typeInfo::DerivedType &,
-    Terminator &, bool hasStat = false, const Descriptor *errMsg = nullptr);
+    Terminator &, bool hasStat = false, const Descriptor *errMsg = nullptr,
+    MemcpyFct memcpyFct = &MemcpyWrapper);
+#else
+RT_API_ATTRS int Initialize(const Descriptor &, const typeInfo::DerivedType &,
+    Terminator &, bool hasStat = false, const Descriptor *errMsg = nullptr,
+    MemcpyFct memcpyFct = &Fortran::runtime::memcpy);
+#endif
 
 // Initializes an object clone from the original object.
 // Each allocatable member of the clone is allocated with the same bounds as
diff --git a/flang-rt/include/flang-rt/runtime/work-queue.h b/flang-rt/include/flang-rt/runtime/work-queue.h
index 7d7f8ad991a57..54a7457741356 100644
--- a/flang-rt/include/flang-rt/runtime/work-queue.h
+++ b/flang-rt/include/flang-rt/runtime/work-queue.h
@@ -249,12 +249,15 @@ class ElementsOverComponents : public Elementwise, public Componentwise {
 class InitializeTicket : public ImmediateTicketRunner<InitializeTicket>,
                          private ElementsOverComponents {
 public:
-  RT_API_ATTRS InitializeTicket(
-      const Descriptor &instance, const typeInfo::DerivedType &derived)
+  RT_API_ATTRS InitializeTicket(const Descriptor &instance,
+      const typeInfo::DerivedType &derived, MemcpyFct memcpyFct)
       : ImmediateTicketRunner<InitializeTicket>{*this},
-        ElementsOverComponents{instance, derived} {}
+        ElementsOverComponents{instance, derived}, memcpyFct_{memcpyFct} {}
   RT_API_ATTRS int Begin(WorkQueue &);
   RT_API_ATTRS int Continue(WorkQueue &);
+
+private:
+  MemcpyFct memcpyFct_;
 };
 
 // Initializes one derived type instance from the value of another
@@ -448,12 +451,19 @@ class WorkQueue {
 
   // APIs for particular tasks.  These can return StatOk if the work is
   // completed immediately.
-  RT_API_ATTRS int BeginInitialize(
-      const Descriptor &descriptor, const typeInfo::DerivedType &derived) {
+#ifdef RT_DEVICE_COMPILATION
+  RT_API_ATTRS int BeginInitialize(const Descriptor &descriptor,
+      const typeInfo::DerivedType &derived,
+      MemcpyFct memcpyFct = &MemcpyWrapper) {
+#else
+  RT_API_ATTRS int BeginInitialize(const Descriptor &descriptor,
+      const typeInfo::DerivedType &derived,
+      MemcpyFct memcpyFct = &Fortran::runtime::memcpy) {
+#endif
     if (runTicketsImmediately_) {
-      return InitializeTicket{descriptor, derived}.Run(*this);
+      return InitializeTicket{descriptor, derived, memcpyFct}.Run(*this);
     } else {
-      StartTicket().u.emplace<InitializeTicket>(descriptor, derived);
+      StartTicket().u.emplace<InitializeTicket>(descriptor, derived, memcpyFct);
       return StatContinue;
     }
   }
diff --git a/flang-rt/lib/cuda/allocatable.cpp b/flang-rt/lib/cuda/allocatable.cpp
index 662703dfb6321..0a7828f8016d5 100644
--- a/flang-rt/lib/cuda/allocatable.cpp
+++ b/flang-rt/lib/cuda/allocatable.cpp
@@ -25,9 +25,9 @@ RT_EXT_API_GROUP_BEGIN
 
 int RTDEF(CUFAllocatableAllocateSync)(Descriptor &desc, int64_t *stream,
     bool *pinned, bool hasStat, const Descriptor *errMsg,
-    const char *sourceFile, int sourceLine) {
-  int stat{RTNAME(CUFAllocatableAllocate)(
-      desc, stream, pinned, hasStat, errMsg, sourceFile, sourceLine)};
+    const char *sourceFile, int sourceLine, bool deviceInit) {
+  int stat{RTNAME(CUFAllocatableAllocate)(desc, stream, pinned, hasStat, errMsg,
+      sourceFile, sourceLine, deviceInit)};
 #ifndef RT_DEVICE_COMPILATION
   // Descriptor synchronization is only done when the allocation is done
   // from the host.
@@ -43,10 +43,10 @@ int RTDEF(CUFAllocatableAllocateSync)(Descriptor &desc, int64_t *stream,
 
 int RTDEF(CUFAllocatableAllocate)(Descriptor &desc, int64_t *stream,
     bool *pinned, bool hasStat, const Descriptor *errMsg,
-    const char *sourceFile, int sourceLine) {
+    const char *sourceFile, int sourceLine, bool deviceInit) {
   // Perform the standard allocation.
-  int stat{RTNAME(AllocatableAllocate)(
-      desc, stream, hasStat, errMsg, sourceFile, sourceLine)};
+  int stat{RTNAME(AllocatableAllocate)(desc, stream, hasStat, errMsg,
+      sourceFile, sourceLine, deviceInit ? &MemcpyHostToDevice : nullptr)};
   if (pinned) {
     // Set pinned according to stat. More infrastructre is needed to set it
     // closer to the actual allocation call.
diff --git a/flang-rt/lib/cuda/memmove-function.cpp b/flang-rt/lib/cuda/memmove-function.cpp
index a7eb0cf1a3e7a..8ebc1250a6687 100644
--- a/flang-rt/lib/cuda/memmove-function.cpp
+++ b/flang-rt/lib/cuda/memmove-function.cpp
@@ -32,4 +32,22 @@ void *MemmoveDeviceToDevice(void *dst, const void *src, std::size_t count) {
   return dst;
 }
 
+void *MemcpyHostToDevice(void *dst, const void *src, std::size_t count) {
+  // TODO: Use cudaMemcpyAsync when we have support for stream.
+  CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, count, cudaMemcpyHostToDevice));
+  return dst;
+}
+
+void *MemcpyDeviceToHost(void *dst, const void *src, std::size_t count) {
+  // TODO: Use cudaMemcpyAsync when we have support for stream.
+  CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToHost));
+  return dst;
+}
+
+void *MemcpyDeviceToDevice(void *dst, const void *src, std::size_t count) {
+  // TODO: Use cudaMemcpyAsync when we have support for stream.
+  CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, count, cudaMemcpyDeviceToDevice));
+  return dst;
+}
+
 } // namespace Fortran::runtime::cuda
diff --git a/flang-rt/lib/cuda/pointer.cpp b/flang-rt/lib/cuda/pointer.cpp
index f07b1a9b60924..bc990c5d27e21 100644
--- a/flang-rt/lib/cuda/pointer.cpp
+++ b/flang-rt/lib/cuda/pointer.cpp
@@ -24,10 +24,10 @@ RT_EXT_API_GROUP_BEGIN
 
 int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t *stream, bool *pinned,
     bool hasStat, const Descriptor *errMsg, const char *sourceFile,
-    int sourceLine) {
+    int sourceLine, bool deviceInit) {
   // Perform the standard allocation.
-  int stat{
-      RTNAME(PointerAllocate)(desc, hasStat, errMsg, sourceFile, sourceLine)};
+  int stat{RTNAME(PointerAllocate)(desc, hasStat, errMsg, sourceFile,
+      sourceLine, deviceInit ? &MemcpyHostToDevice : nullptr)};
   if (pinned) {
     // Set pinned according to stat. More infrastructre is needed to set it
     // closer to the actual allocation call.
@@ -38,9 +38,9 @@ int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t *stream, bool *pinned,
 
 int RTDEF(CUFPointerAllocateSync)(Descriptor &desc, int64_t *stream,
     bool *pinned, bool hasStat, const Descriptor *errMsg,
-    const char *sourceFile, int sourceLine) {
-  int stat{RTNAME(CUFPointerAllocate)(
-      desc, stream, pinned, hasStat, errMsg, sourceFile, sourceLine)};
+    const char *sourceFile, int sourceLine, bool deviceInit) {
+  int stat{RTNAME(CUFPointerAllocate)(desc, stream, pinned, hasStat, errMsg,
+      sourceFile, sourceLine, deviceInit)};
 #ifndef RT_DEVICE_COMPILATION
   // Descriptor synchronization is only done when the allocation is done
   // from the host.
diff --git a/flang-rt/lib/runtime/allocatable.cpp b/flang-rt/lib/runtime/allocatable.cpp
index f724f0a20884b..5b3db1e47238b 100644
--- a/flang-rt/lib/runtime/allocatable.cpp
+++ b/flang-rt/lib/runtime/allocatable.cpp
@@ -135,7 +135,7 @@ void RTDEF(AllocatableApplyMold)(
 
 int RTDEF(AllocatableAllocate)(Descriptor &descriptor,
     std::int64_t *asyncObject, bool hasStat, const Descriptor *errMsg,
-    const char *sourceFile, int sourceLine) {
+    const char *sourceFile, int sourceLine, MemcpyFct memcpyFct) {
   Terminator terminator{sourceFile, sourceLine};
   if (!descriptor.IsAllocatable()) {
     return ReturnError(terminator, StatInvalidDescriptor, errMsg, hasStat);
@@ -148,8 +148,8 @@ int RTDEF(AllocatableAllocate)(Descriptor &descriptor,
       if (const DescriptorAddendum * addendum{descriptor.Addendum()}) {
         if (const auto *derived{addendum->derivedType()}) {
           if (!derived->noInitializationNeeded()) {
-            stat =
-                Initialize(descriptor, *derived, terminator, hasStat, errMsg);
+            stat = Initialize(
+                descriptor, *derived, terminator, hasStat, errMsg, memcpyFct);
           }
         }
       }
diff --git a/flang-rt/lib/runtime/derived.cpp b/flang-rt/lib/runtime/derived.cpp
index 7e50674631624..7fc426b9efc9a 100644
--- a/flang-rt/lib/runtime/derived.cpp
+++ b/flang-rt/lib/runtime/derived.cpp
@@ -13,6 +13,7 @@
 #include "flang-rt/runtime/tools.h"
 #include "flang-rt/runtime/type-info.h"
 #include "flang-rt/runtime/work-queue.h"
+#include "flang/Runtime/CUDA/memmove-function.h"
 
 namespace Fortran::runtime {
 
@@ -32,9 +33,9 @@ static RT_API_ATTRS void GetComponentExtents(SubscriptValue (&extents)[maxRank],
 
 RT_API_ATTRS int Initialize(const Descriptor &instance,
     const typeInfo::DerivedType &derived, Terminator &terminator, bool,
-    const Descriptor *) {
+    const Descriptor *, MemcpyFct memcpyFct) {
   WorkQueue workQueue{terminator};
-  int status{workQueue.BeginInitialize(instance, derived)};
+  int status{workQueue.BeginInitialize(instance, derived, memcpyFct)};
   return status == StatContinue ? workQueue.Run() : status;
 }
 
@@ -72,7 +73,11 @@ RT_API_ATTRS int InitializeTicket::Continue(WorkQueue &workQueue) {
       // Explicit initialization of data pointers and
       // non-allocatable non-automatic components
       std::size_t bytes{component_->SizeInBytes(instance_)};
-      runtime::memcpy(rawComponent, init, bytes);
+      if (memcpyFct_) {
+        memcpyFct_(rawComponent, init, bytes);
+      } else {
+        Fortran::runtime::memcpy(rawComponent, init, bytes);
+      }
     } else if (component_->genre() == typeInfo::Component::Genre::Pointer ||
         component_->genre() == typeInfo::Component::Genre::PointerDevice) {
       // Data pointers without explicit initialization are established
@@ -110,20 +115,33 @@ RT_API_ATTRS int InitializeTicket::Continue(WorkQueue &workQueue) {
             chunk = done;
           }
           char *uninitialized{rawInstance + done * *stride};
-          runtime::memcpy(uninitialized, rawInstance, chunk * *stride);
+          if (memcpyFct_) {
+            memcpyFct_(uninitialized, rawInstance, chunk * *stride);
+          } else {
+            Fortran::runtime::memcpy(
+                uninitialized, rawInstance, chunk * *stride);
+          }
           done += chunk;
         }
       } else {
         for (std::size_t done{1}; done < elements_; ++done) {
           char *uninitialized{rawInstance + done * *stride};
-          runtime::memcpy(uninitialized, rawInstance, elementBytes);
+          if (memcpyFct_) {
+            memcpyFct_(uninitialized, rawInstance, elementBytes);
+          } else {
+            Fortran::runtime::memcpy(uninitialized, rawInstance, elementBytes);
+          }
         }
       }
     } else { // one at a time with subscription
       for (Elementwise::Advance(); !Elementwise::IsComplete();
           Elementwise::Advance()) {
         char *element{instance_.Element<char>(subscripts_)};
-        runtime::memcpy(element, rawInstance, elementBytes);
+        if (memcpyFct_) {
+          memcpyFct_(element, rawInstance, elementBytes);
+        } else {
+          Fortran::runtime::memcpy(element, rawInstance, elementBytes);
+        }
       }
     }
   }
diff --git a/flang-rt/lib/runtime/pointer.cpp b/flang-rt/lib/runtime/pointer.cpp
index f8ada65541a1a..0832b5656f1ab 100644
--- a/flang-rt/lib/runtime/pointer.cpp
+++ b/flang-rt/lib/runtime/pointer.cpp
@@ -157,7 +157,8 @@ RT_API_ATTRS void *AllocateValidatedPointerPayload(
 }
 
 int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
-    const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
+    const Descriptor *errMsg, const char *sourceFile, int sourceLine,
+    MemcpyFct memcpyFct) {
   Terminator terminator{sourceFile, sourceLine};
   if (!pointer.IsPointer()) {
     return ReturnError(terminator, StatInvalidDescriptor, errMsg, hasStat);
@@ -179,7 +180,8 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
   if (const DescriptorAddendum * addendum{pointer.Addendum()}) {
     if (const auto *derived{addendum->derivedType()}) {
       if (!derived->noInitializationNeeded()) {
-        stat = Initialize(pointer, *derived, terminator, hasStat, errMsg);
+        stat = Initialize(
+            pointer, *derived, terminator, hasStat, errMsg, memcpyFct);
       }
     }
   }
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
index 98d7de81c7f08..960405ee0006f 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
@@ -252,6 +252,33 @@ constexpr TypeBuilderFunc getModel<void (*)(int)>() {
   };
 }
 template <>
+constexpr TypeBuilderFunc
+getModel<void *(*)(void *, const void *, unsigned long)>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    auto voidPtrTy =
+        fir::LLVMPointerType::get(context, mlir::IntegerType::get(context, 8));
+    auto unsignedLongTy =
+        mlir::IntegerType::get(context, 8 * sizeof(unsigned long));
+    auto funcTy = mlir::FunctionType::get(
+        context, {voidPtrTy, voidPtrTy, unsignedLongTy}, {voidPtrTy});
+    return fir::LLVMPointerType::get(context, funcTy);
+  };
+}
+#ifdef _MSC_VER
+template <>
+constexpr TypeBuilderFunc
+getModel<void *(*)(void *, const void *, unsigned __int64)>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    auto voidPtrTy =
+        fir::LLVMPointerType::get(context, mlir::IntegerType::get(context, 8));
+    auto uint64Ty = mlir::IntegerType::get(context, 64);
+    auto funcTy = mlir::FunctionType::get(
+        context, {voidPtrTy, voidPtrTy, uint64Ty}, {voidPtrTy});
+    return fir::LLVMPointerType::get(context, funcTy);
+  };
+}
+#endif
+template <>
 constexpr TypeBuilderFunc getModel<void **>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
     return fir::ReferenceType::get(
diff --git a/flang/include/flang/Runtime/CUDA/allocatable.h b/flang/include/flang/Runtime/CUDA/allocatable.h
index 97f24bc34bfb8..d5a649594ae92 100644
--- a/flang/include/flang/Runtime/CUDA/allocatable.h
+++ b/flang/include/flang/Runtime/CUDA/allocatable.h
@@ -20,14 +20,14 @@ extern "C" {
 int RTDECL(CUFAllocatableAllocate)(Descriptor &, int64_t *stream = nullptr,
     bool *pinned = nullptr, bool hasStat = false,
     const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
-    int sourceLine = 0);
+    int sourceLine = 0, bool deviceInit = false);
 
 /// Perform allocation of the descriptor with synchronization of it when
 /// necessary.
 int RTDECL(CUFAllocatableAllocateSync)(Descriptor &, int64_t *stream = nullptr,
     bool *pinned = nullptr, bool hasStat = false,
     const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
-    int sourceLine = 0);
+    int sourceLine = 0, bool deviceInit = false);
 
 /// Perform allocation of the descriptor without synchronization. Assign data
 /// from source.
diff --git a/flang/include/flang/Runtime/CUDA/memmove-function.h b/flang/include/flang/Runtime/CUDA/memmove-function.h
index 74d6a05eff4c9..765600db4b620 100644
--- a/flang/include/flang/Runtime/CUDA/memmove-function.h
+++ b/flang/include/flang/Runtime/CUDA/memmove-function.h
@@ -19,5 +19,11 @@ void *MemmoveDeviceToHost(void *dst, const void *src, std::size_t count);
 
 void *MemmoveDeviceToDevice(void *dst, const void *src, std::size_t count);
 
+void *MemcpyHostToDevice(void *dst, const void *src, std::size_t count);
+
+void *MemcpyDeviceToHost(void *dst, const void *src, std::size_t count);
+
+void *MemcpyDeviceToDevice(void *dst, const void *src, std::size_t count);
+
 } // namespace Fortran::runtime::cuda
 #endif // FORTRAN_RUNTIME_CUDA_MEMMOVE_FUNCTION_H_
diff --git a/flang/include/flang/Runtime/CUDA/pointer.h b/flang/include/flang/Runtime/CUDA/pointer.h
index b845fd59114d4..4e49691d127e1 100644
--- a/flang/include/flang/Runtime/CUDA/pointer.h
+++ b/flang/include/flang/Runtime/CUDA/pointer.h
@@ -20,14 +20,14 @@ extern "C" {
 int RTDECL(CUFPointerAllocate)(Descriptor &, int64_t *stream = nullptr,
     bool *pinned = nullptr, bool hasStat = false,
     const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
-    int sourceLine = 0);
+    int sourceLine = 0, bool deviceInit = false);
 
 /// Perform allocation of the descriptor with synchronization of it when
 /// necessary.
 int RTDECL(CUFPointerAllocateSync)(Descriptor &, int64_t *stream = nullptr,
     bool *pinned = nullptr, bool hasStat = false,
     const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
-    int sourceLine = 0);
+    int sourceLine = 0, bool deviceInit = false);
 
 /// Perform allocation of the descriptor without synchronization. Assign data
 /// from source.
diff --git a/flang/include/flang/Runtime/allocatable.h b/flang/include/flang/Runtime/allocatable.h
index 863c07494e7c3..ba065331e3922 100644
--- a/flang/include/flang/Runtime/allocatable.h
+++ b/flang/include/flang/Runtime/allocatable.h
@@ -13,6 +13,7 @@
 
 #include "flang/Runtime/descriptor-consts.h"
 #include "flang/Runtime/entry-names.h"
+#include "flang/Runtime/freestanding-tools.h"
 
 namespace Fortran::runtime {
 
@@ -94,10 +95,17 @@ int RTDECL(AllocatableCheckLengthParameter)(Descriptor &,
 // Successfully allocated memory is initialized if the allocatable has a
 // derived type, and is always initialized by AllocatableAllocateSource().
 // Performs all necessary coarray synchronization and validation actions.
+#ifdef RT_DEVICE_COMPILATION
 int RTDECL(AllocatableAllocate)(Descriptor &,
     std::int64_t *asyncObject = nullptr, bool hasStat = false,
     const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
-    int sourceLine = 0);
+    int sourceLine = 0, MemcpyFct memcpyFct = &MemcpyWrapper);
+#else
+int RTDECL(AllocatableAllocate)(Descriptor &,
+    std::int64_t *asyncObject = nullptr, bool hasStat = false,
+    const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
+    int sourceLine = 0, MemcpyFct memcpyFct = &Fortran::runtime::memcpy);
+#endif
 int RTDECL(AllocatableAllocateSource)(Descriptor &, const Descriptor &source,
     bool hasStat = false, const Descriptor *errMsg = nullptr,
     const char *sourceFile = nullptr, int sourceLine = 0);
diff --git a/flang/include/flang/Runtime/freestanding-tools.h b/flang/include/flang/Runtime/freestanding-tools.h
index 7ef7cc74f213b..7ab06145c5d52 100644
--- a/flang/include/flang/Runtime/freestanding-tools.h
+++ b/flang/include/flang/Runtime/freestanding-tools.h
@@ -122,7 +122,7 @@ static inline RT_API_ATTRS void memcpy(
   __builtin_memcpy(dest, src, count);
 }
 #elif STD_MEMCPY_UNSUPPORTED
-static inline RT_API_ATTRS void memcpy(
+static inline RT_API_ATTRS void *memcpy(
     void *dest, const void *src, std::size_t count) {
   char *to{reinterpret_cast<char *>(dest)};
   const char *from{reinterpret_cast<const char *>(src)};
@@ -132,6 +132,7 @@ static inline RT_API_ATTRS void memcpy(
   while (count--) {
     *to++ = *from++;
   }
+  return dest;
 }
 #else
 using std::memcpy;
@@ -173,12 +174,17 @@ using std::memmove;
 #endif // !STD_MEMMOVE_UNSUPPORTED
 
 using MemmoveFct = void *(*)(void *, const void *, std::size_t);
+using MemcpyFct = void *(*)(void *, const void *, std::size_t);
 
 #ifdef RT_DEVICE_COMPILATION
 [[maybe_unused]] static RT_API_ATTRS void *MemmoveWrapper(
     void *dest, const void *src, std::size_t count) {
   return Fortran::runtime::memmove(dest, src, count);
 }
+[[maybe_unused]] static RT_API_ATTRS void *MemcpyWrapper(
+    void *dest, const void *src, std::size_t count) {
+  return Fortran::runtime::memcpy(dest, src, count);
+}
 #endif
 
 #if STD_STRLEN_USE_BUILT...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172913


More information about the llvm-commits mailing list