[llvm] [flang-rt] Optimise ShallowCopy and use it in CopyInAssign (PR #140569)

Kajetan Puchalski via llvm-commits llvm-commits at lists.llvm.org
Thu May 22 05:00:24 PDT 2025


https://github.com/mrkajetanp updated https://github.com/llvm/llvm-project/pull/140569

>From 8f24bb50b4a7ca92adc3f76e62e61d125f6c38e4 Mon Sep 17 00:00:00 2001
From: Kajetan Puchalski <kajetan.puchalski at arm.com>
Date: Fri, 16 May 2025 18:34:22 +0000
Subject: [PATCH 1/5] [flang-rt] Optimise ShallowCopy and elemental copies in
 Assign

Using Descriptor.Element<>() when iterating through a rank-1 array is
currently inefficient, because the generic implementation suitable
for arrays of any rank makes the compiler unable to perform
optimisations that would make the rank-1 case considerably faster.

This is currently done inside ShallowCopy, as well as inside Assign
where the implementation of elemental copies is equivalent to
ShallowCopyDiscontiguousToDiscontiguous.

To address that, add a DescriptorIterator abstraction specialised both
for the optimised rank-1 case as well as for the generic case, and use
that throughout ShallowCopy to iterate over the arrays.

Furthermore, depending on the pointer type passed to memcpy, the
optimiser can remove the memcpy calls from ShallowCopy altogether which
can result in substantial performance improvements on its own. Check the
element size throughout ShallowCopy and use the pointer type that
matches it where applicable to make these optimisations possible.

Finally, replace the implementation of elemental copies inside Assign to
make use of the ShallowCopy* family of functions whenever possible.

For the thornado-mini application, this reduces the runtime by 27.7%.

Signed-off-by: Kajetan Puchalski <kajetan.puchalski at arm.com>
---
 .../include/flang-rt/runtime/descriptor.h     | 37 ++++++++
 flang-rt/include/flang-rt/runtime/tools.h     |  3 +
 flang-rt/lib/runtime/assign.cpp               | 20 +++--
 flang-rt/lib/runtime/tools.cpp                | 84 +++++++++++++++----
 4 files changed, 122 insertions(+), 22 deletions(-)

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index 9907e7866e7bf..6f60854584e30 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -437,6 +437,43 @@ class Descriptor {
 };
 static_assert(sizeof(Descriptor) == sizeof(ISO::CFI_cdesc_t));
 
+// Lightweight iterator-like API to simplify specialising Descriptor indexing
+// in cases where it can improve application performance. On account of the
+// purpose of this API being performance optimisation, it is up to the user to
+// do all the necessary checks to make sure the RANK1=true variant can be used
+// safely and that Advance() is not called more times than the number of
+// elements in the Descriptor allows for.
+template <bool RANK1 = false> class DescriptorIterator {
+private:
+  const Descriptor &descriptor;
+  SubscriptValue subscripts[maxRank];
+  std::size_t elementOffset = 0;
+
+public:
+  DescriptorIterator(const Descriptor &descriptor) : descriptor(descriptor) {
+    descriptor.GetLowerBounds(subscripts);
+    if constexpr (RANK1) {
+      elementOffset = descriptor.SubscriptByteOffset(0, subscripts[0]);
+    }
+  };
+
+  template <typename A> A *Get() {
+    if constexpr (RANK1) {
+      return descriptor.OffsetElement<A>(elementOffset);
+    } else {
+      return descriptor.Element<A>(subscripts);
+    }
+  }
+
+  void Advance() {
+    if constexpr (RANK1) {
+      elementOffset += descriptor.GetDimension(0).ByteStride();
+    } else {
+      descriptor.IncrementSubscripts(subscripts);
+    }
+  }
+};
+
 // Properly configured instances of StaticDescriptor will occupy the
 // exact amount of storage required for the descriptor, its dimensional
 // information, and possible addendum.  To build such a static descriptor,
diff --git a/flang-rt/include/flang-rt/runtime/tools.h b/flang-rt/include/flang-rt/runtime/tools.h
index 91a026bf2ac14..e0ebebcdc18da 100644
--- a/flang-rt/include/flang-rt/runtime/tools.h
+++ b/flang-rt/include/flang-rt/runtime/tools.h
@@ -511,10 +511,13 @@ inline RT_API_ATTRS const char *FindCharacter(
 // Copy payload data from one allocated descriptor to another.
 // Assumes element counts and element sizes match, and that both
 // descriptors are allocated.
+template <bool RANK1 = false>
 RT_API_ATTRS void ShallowCopyDiscontiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from);
+template <bool RANK1 = false>
 RT_API_ATTRS void ShallowCopyDiscontiguousToContiguous(
     const Descriptor &to, const Descriptor &from);
+template <bool RANK1 = false>
 RT_API_ATTRS void ShallowCopyContiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from);
 RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
diff --git a/flang-rt/lib/runtime/assign.cpp b/flang-rt/lib/runtime/assign.cpp
index 4a813cd489022..627db02ac3f1a 100644
--- a/flang-rt/lib/runtime/assign.cpp
+++ b/flang-rt/lib/runtime/assign.cpp
@@ -492,11 +492,21 @@ RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,
         terminator.Crash("unexpected type code %d in blank padded Assign()",
             to.type().raw());
       }
-    } else { // elemental copies, possibly with character truncation
-      for (std::size_t n{toElements}; n-- > 0;
-           to.IncrementSubscripts(toAt), from.IncrementSubscripts(fromAt)) {
-        memmoveFct(to.Element<char>(toAt), from.Element<const char>(fromAt),
-            toElementBytes);
+    } else {
+      // We can't simply call ShallowCopy due to edge cases such as character
+      // truncation or assignments where the RHS is a scalar.
+      if (toElementBytes == fromElementBytes && to.IsContiguous()) {
+        if (to.rank() == 1 && from.rank() == 1) {
+          ShallowCopyDiscontiguousToContiguous<true>(to, from);
+        } else {
+          ShallowCopyDiscontiguousToContiguous<false>(to, from);
+        }
+      } else {
+        if (to.rank() == 1 && from.rank() == 1) {
+          ShallowCopyDiscontiguousToDiscontiguous<true>(to, from);
+        } else {
+          ShallowCopyDiscontiguousToDiscontiguous<false>(to, from);
+        }
       }
     }
   }
diff --git a/flang-rt/lib/runtime/tools.cpp b/flang-rt/lib/runtime/tools.cpp
index 5d6e35faca70a..13811bf38fbb3 100644
--- a/flang-rt/lib/runtime/tools.cpp
+++ b/flang-rt/lib/runtime/tools.cpp
@@ -114,40 +114,78 @@ RT_API_ATTRS void CheckIntegerKind(
   }
 }
 
+template <bool RANK1>
 RT_API_ATTRS void ShallowCopyDiscontiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from) {
-  SubscriptValue toAt[maxRank], fromAt[maxRank];
-  to.GetLowerBounds(toAt);
-  from.GetLowerBounds(fromAt);
+  DescriptorIterator<RANK1> toIt{to};
+  DescriptorIterator<RANK1> fromIt{from};
   std::size_t elementBytes{to.ElementBytes()};
   for (std::size_t n{to.Elements()}; n-- > 0;
-       to.IncrementSubscripts(toAt), from.IncrementSubscripts(fromAt)) {
-    std::memcpy(
-        to.Element<char>(toAt), from.Element<char>(fromAt), elementBytes);
+      toIt.Advance(), fromIt.Advance()) {
+    // Checking the size at runtime and making sure the pointer passed to memcpy
+    // has a type that matches the element size makes it possible for the
+    // compiler to optimise out the memcpy calls altogether and can
+    // substantially improve performance for some applications.
+    if (elementBytes == 16) {
+      std::memcpy(toIt.template Get<__int128_t>(),
+          fromIt.template Get<__int128_t>(), elementBytes);
+    } else if (elementBytes == 8) {
+      std::memcpy(toIt.template Get<int64_t>(), fromIt.template Get<int64_t>(),
+          elementBytes);
+    } else if (elementBytes == 4) {
+      std::memcpy(toIt.template Get<int32_t>(), fromIt.template Get<int32_t>(),
+          elementBytes);
+    } else if (elementBytes == 2) {
+      std::memcpy(toIt.template Get<int16_t>(), fromIt.template Get<int16_t>(),
+          elementBytes);
+    } else {
+      std::memcpy(
+          toIt.template Get<char>(), fromIt.template Get<char>(), elementBytes);
+    }
   }
 }
 
+template <bool RANK1>
 RT_API_ATTRS void ShallowCopyDiscontiguousToContiguous(
     const Descriptor &to, const Descriptor &from) {
   char *toAt{to.OffsetElement()};
-  SubscriptValue fromAt[maxRank];
-  from.GetLowerBounds(fromAt);
   std::size_t elementBytes{to.ElementBytes()};
+  DescriptorIterator<RANK1> fromIt{from};
   for (std::size_t n{to.Elements()}; n-- > 0;
-       toAt += elementBytes, from.IncrementSubscripts(fromAt)) {
-    std::memcpy(toAt, from.Element<char>(fromAt), elementBytes);
+      toAt += elementBytes, fromIt.Advance()) {
+    if (elementBytes == 16) {
+      std::memcpy(toAt, fromIt.template Get<__int128_t>(), elementBytes);
+    } else if (elementBytes == 8) {
+      std::memcpy(toAt, fromIt.template Get<int64_t>(), elementBytes);
+    } else if (elementBytes == 4) {
+      std::memcpy(toAt, fromIt.template Get<int32_t>(), elementBytes);
+    } else if (elementBytes == 2) {
+      std::memcpy(toAt, fromIt.template Get<int16_t>(), elementBytes);
+    } else {
+      std::memcpy(toAt, fromIt.template Get<char>(), elementBytes);
+    }
   }
 }
 
+template <bool RANK1>
 RT_API_ATTRS void ShallowCopyContiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from) {
-  SubscriptValue toAt[maxRank];
-  to.GetLowerBounds(toAt);
   char *fromAt{from.OffsetElement()};
+  DescriptorIterator<RANK1> toIt{to};
   std::size_t elementBytes{to.ElementBytes()};
   for (std::size_t n{to.Elements()}; n-- > 0;
-       to.IncrementSubscripts(toAt), fromAt += elementBytes) {
-    std::memcpy(to.Element<char>(toAt), fromAt, elementBytes);
+      toIt.Advance(), fromAt += elementBytes) {
+    if (elementBytes == 16) {
+      std::memcpy(toIt.template Get<__int128_t>(), fromAt, elementBytes);
+    } else if (elementBytes == 8) {
+      std::memcpy(toIt.template Get<int64_t>(), fromAt, elementBytes);
+    } else if (elementBytes == 4) {
+      std::memcpy(toIt.template Get<int32_t>(), fromAt, elementBytes);
+    } else if (elementBytes == 2) {
+      std::memcpy(toIt.template Get<int16_t>(), fromAt, elementBytes);
+    } else {
+      std::memcpy(toIt.template Get<char>(), fromAt, elementBytes);
+    }
   }
 }
 
@@ -158,13 +196,25 @@ RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
       std::memcpy(to.OffsetElement(), from.OffsetElement(),
           to.Elements() * to.ElementBytes());
     } else {
-      ShallowCopyDiscontiguousToContiguous(to, from);
+      if (to.rank() == 1 && from.rank() == 1) {
+        ShallowCopyDiscontiguousToContiguous<true>(to, from);
+      } else {
+        ShallowCopyDiscontiguousToContiguous<false>(to, from);
+      }
     }
   } else {
     if (fromIsContiguous) {
-      ShallowCopyContiguousToDiscontiguous(to, from);
+      if (to.rank() == 1 && from.rank() == 1) {
+        ShallowCopyContiguousToDiscontiguous<true>(to, from);
+      } else {
+        ShallowCopyContiguousToDiscontiguous<false>(to, from);
+      }
     } else {
-      ShallowCopyDiscontiguousToDiscontiguous(to, from);
+      if (to.rank() == 1 && from.rank() == 1) {
+        ShallowCopyDiscontiguousToDiscontiguous<true>(to, from);
+      } else {
+        ShallowCopyDiscontiguousToDiscontiguous<false>(to, from);
+      }
     }
   }
 }

>From 79a353e3b810bde85bfbd8140a8daa4ec16a0f32 Mon Sep 17 00:00:00 2001
From: Kajetan Puchalski <kajetan.puchalski at arm.com>
Date: Tue, 20 May 2025 14:49:14 +0000
Subject: [PATCH 2/5] Rework in line with Slava's review comments

---
 .../include/flang-rt/runtime/descriptor.h     |  41 +++--
 flang-rt/include/flang-rt/runtime/tools.h     |   6 +-
 flang-rt/lib/runtime/assign.cpp               |  30 ++--
 flang-rt/lib/runtime/tools.cpp                | 151 +++++++++++-------
 4 files changed, 138 insertions(+), 90 deletions(-)

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index 6f60854584e30..622e1970607d8 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -440,34 +440,55 @@ static_assert(sizeof(Descriptor) == sizeof(ISO::CFI_cdesc_t));
 // Lightweight iterator-like API to simplify specialising Descriptor indexing
 // in cases where it can improve application performance. On account of the
 // purpose of this API being performance optimisation, it is up to the user to
-// do all the necessary checks to make sure the RANK1=true variant can be used
+// do all the necessary checks to make sure the specialised variants can be used
 // safely and that Advance() is not called more times than the number of
 // elements in the Descriptor allows for.
-template <bool RANK1 = false> class DescriptorIterator {
+// Default RANK=-1 supports array descriptors of any rank up to maxRank.
+template <int RANK = -1> class DescriptorIterator {
 private:
   const Descriptor &descriptor;
   SubscriptValue subscripts[maxRank];
   std::size_t elementOffset = 0;
 
 public:
-  DescriptorIterator(const Descriptor &descriptor) : descriptor(descriptor) {
+  RT_API_ATTRS DescriptorIterator(const Descriptor &descriptor)
+      : descriptor(descriptor) {
     descriptor.GetLowerBounds(subscripts);
-    if constexpr (RANK1) {
+    if constexpr (RANK == 1) {
       elementOffset = descriptor.SubscriptByteOffset(0, subscripts[0]);
     }
   };
 
-  template <typename A> A *Get() {
-    if constexpr (RANK1) {
-      return descriptor.OffsetElement<A>(elementOffset);
+  template <typename A> RT_API_ATTRS A *Get() {
+    std::size_t offset = 0;
+    // The rank-1 case doesn't require looping at all
+    if constexpr (RANK == 1) {
+      offset = elementOffset;
+      // The compiler might be able to optimise this better if we know the rank
+      // at compile time
+    } else if (RANK != -1) {
+      for (int j{0}; j < RANK; ++j) {
+        offset += descriptor.SubscriptByteOffset(j, subscripts[j]);
+      }
+      // General fallback
     } else {
-      return descriptor.Element<A>(subscripts);
+      offset = descriptor.SubscriptsToByteOffset(subscripts);
     }
+
+    return descriptor.OffsetElement<A>(offset);
   }
 
-  void Advance() {
-    if constexpr (RANK1) {
+  RT_API_ATTRS void Advance() {
+    if constexpr (RANK == 1) {
       elementOffset += descriptor.GetDimension(0).ByteStride();
+    } else if (RANK != -1) {
+      for (int j{0}; j < RANK; ++j) {
+        const Dimension &dim{descriptor.GetDimension(j)};
+        if (subscripts[j]++ < dim.UpperBound()) {
+          break;
+        }
+        subscripts[j] = dim.LowerBound();
+      }
     } else {
       descriptor.IncrementSubscripts(subscripts);
     }
diff --git a/flang-rt/include/flang-rt/runtime/tools.h b/flang-rt/include/flang-rt/runtime/tools.h
index e0ebebcdc18da..a1b96f41f4936 100644
--- a/flang-rt/include/flang-rt/runtime/tools.h
+++ b/flang-rt/include/flang-rt/runtime/tools.h
@@ -511,13 +511,13 @@ inline RT_API_ATTRS const char *FindCharacter(
 // Copy payload data from one allocated descriptor to another.
 // Assumes element counts and element sizes match, and that both
 // descriptors are allocated.
-template <bool RANK1 = false>
+template <typename P = char, int RANK = -1>
 RT_API_ATTRS void ShallowCopyDiscontiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from);
-template <bool RANK1 = false>
+template <typename P = char, int RANK = -1>
 RT_API_ATTRS void ShallowCopyDiscontiguousToContiguous(
     const Descriptor &to, const Descriptor &from);
-template <bool RANK1 = false>
+template <typename P = char, int RANK = -1>
 RT_API_ATTRS void ShallowCopyContiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from);
 RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
diff --git a/flang-rt/lib/runtime/assign.cpp b/flang-rt/lib/runtime/assign.cpp
index 627db02ac3f1a..9f4dcfa7d86c1 100644
--- a/flang-rt/lib/runtime/assign.cpp
+++ b/flang-rt/lib/runtime/assign.cpp
@@ -492,21 +492,11 @@ RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,
         terminator.Crash("unexpected type code %d in blank padded Assign()",
             to.type().raw());
       }
-    } else {
-      // We can't simply call ShallowCopy due to edge cases such as character
-      // truncation or assignments where the RHS is a scalar.
-      if (toElementBytes == fromElementBytes && to.IsContiguous()) {
-        if (to.rank() == 1 && from.rank() == 1) {
-          ShallowCopyDiscontiguousToContiguous<true>(to, from);
-        } else {
-          ShallowCopyDiscontiguousToContiguous<false>(to, from);
-        }
-      } else {
-        if (to.rank() == 1 && from.rank() == 1) {
-          ShallowCopyDiscontiguousToDiscontiguous<true>(to, from);
-        } else {
-          ShallowCopyDiscontiguousToDiscontiguous<false>(to, from);
-        }
+    } else { // elemental copies, possibly with character truncation
+      for (std::size_t n{toElements}; n-- > 0;
+          to.IncrementSubscripts(toAt), from.IncrementSubscripts(fromAt)) {
+        memmoveFct(to.Element<char>(toAt), from.Element<const char>(fromAt),
+            toElementBytes);
       }
     }
   }
@@ -598,7 +588,8 @@ void RTDEF(CopyInAssign)(Descriptor &temp, const Descriptor &var,
   temp = var;
   temp.set_base_addr(nullptr);
   temp.raw().attribute = CFI_attribute_allocatable;
-  RTNAME(AssignTemporary)(temp, var, sourceFile, sourceLine);
+  temp.Allocate(kNoAsyncId);
+  ShallowCopy(temp, var);
 }
 
 void RTDEF(CopyOutAssign)(
@@ -607,9 +598,10 @@ void RTDEF(CopyOutAssign)(
 
   // Copyout from the temporary must not cause any finalizations
   // for LHS. The variable must be properly initialized already.
-  if (var)
-    Assign(*var, temp, terminator, NoAssignFlags);
-  temp.Destroy(/*finalize=*/false, /*destroyPointers=*/false, &terminator);
+  if (var) {
+    ShallowCopy(*var, temp);
+  }
+  temp.Deallocate();
 }
 
 void RTDEF(AssignExplicitLengthCharacter)(Descriptor &to,
diff --git a/flang-rt/lib/runtime/tools.cpp b/flang-rt/lib/runtime/tools.cpp
index 13811bf38fbb3..2387ce612fc5e 100644
--- a/flang-rt/lib/runtime/tools.cpp
+++ b/flang-rt/lib/runtime/tools.cpp
@@ -114,111 +114,146 @@ RT_API_ATTRS void CheckIntegerKind(
   }
 }
 
-template <bool RANK1>
+template <typename P, int RANK>
 RT_API_ATTRS void ShallowCopyDiscontiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from) {
-  DescriptorIterator<RANK1> toIt{to};
-  DescriptorIterator<RANK1> fromIt{from};
+  DescriptorIterator<RANK> toIt{to};
+  DescriptorIterator<RANK> fromIt{from};
+  // Knowing the size at compile time can enable memcpy inlining optimisations
+  constexpr std::size_t typeElementBytes{sizeof(P)};
+  // We might still need to check the actual size as a fallback
   std::size_t elementBytes{to.ElementBytes()};
   for (std::size_t n{to.Elements()}; n-- > 0;
       toIt.Advance(), fromIt.Advance()) {
-    // Checking the size at runtime and making sure the pointer passed to memcpy
-    // has a type that matches the element size makes it possible for the
-    // compiler to optimise out the memcpy calls altogether and can
-    // substantially improve performance for some applications.
-    if (elementBytes == 16) {
-      std::memcpy(toIt.template Get<__int128_t>(),
-          fromIt.template Get<__int128_t>(), elementBytes);
-    } else if (elementBytes == 8) {
-      std::memcpy(toIt.template Get<int64_t>(), fromIt.template Get<int64_t>(),
-          elementBytes);
-    } else if (elementBytes == 4) {
-      std::memcpy(toIt.template Get<int32_t>(), fromIt.template Get<int32_t>(),
-          elementBytes);
-    } else if (elementBytes == 2) {
-      std::memcpy(toIt.template Get<int16_t>(), fromIt.template Get<int16_t>(),
-          elementBytes);
+    // typeElementBytes == 1 when P is a char - the non-specialised case
+    if constexpr (typeElementBytes != 1) {
+      std::memcpy(
+          toIt.template Get<P>(), fromIt.template Get<P>(), typeElementBytes);
     } else {
       std::memcpy(
-          toIt.template Get<char>(), fromIt.template Get<char>(), elementBytes);
+          toIt.template Get<P>(), fromIt.template Get<P>(), elementBytes);
     }
   }
 }
 
-template <bool RANK1>
+template <typename P, int RANK>
 RT_API_ATTRS void ShallowCopyDiscontiguousToContiguous(
     const Descriptor &to, const Descriptor &from) {
   char *toAt{to.OffsetElement()};
+  constexpr std::size_t typeElementBytes{sizeof(P)};
   std::size_t elementBytes{to.ElementBytes()};
-  DescriptorIterator<RANK1> fromIt{from};
+  DescriptorIterator<RANK> fromIt{from};
   for (std::size_t n{to.Elements()}; n-- > 0;
       toAt += elementBytes, fromIt.Advance()) {
-    if (elementBytes == 16) {
-      std::memcpy(toAt, fromIt.template Get<__int128_t>(), elementBytes);
-    } else if (elementBytes == 8) {
-      std::memcpy(toAt, fromIt.template Get<int64_t>(), elementBytes);
-    } else if (elementBytes == 4) {
-      std::memcpy(toAt, fromIt.template Get<int32_t>(), elementBytes);
-    } else if (elementBytes == 2) {
-      std::memcpy(toAt, fromIt.template Get<int16_t>(), elementBytes);
+    if constexpr (typeElementBytes != 1) {
+      std::memcpy(toAt, fromIt.template Get<P>(), typeElementBytes);
     } else {
-      std::memcpy(toAt, fromIt.template Get<char>(), elementBytes);
+      std::memcpy(toAt, fromIt.template Get<P>(), elementBytes);
     }
   }
 }
 
-template <bool RANK1>
+template <typename P, int RANK>
 RT_API_ATTRS void ShallowCopyContiguousToDiscontiguous(
     const Descriptor &to, const Descriptor &from) {
   char *fromAt{from.OffsetElement()};
-  DescriptorIterator<RANK1> toIt{to};
+  DescriptorIterator<RANK> toIt{to};
+  constexpr std::size_t typeElementBytes{sizeof(P)};
   std::size_t elementBytes{to.ElementBytes()};
   for (std::size_t n{to.Elements()}; n-- > 0;
       toIt.Advance(), fromAt += elementBytes) {
-    if (elementBytes == 16) {
-      std::memcpy(toIt.template Get<__int128_t>(), fromAt, elementBytes);
-    } else if (elementBytes == 8) {
-      std::memcpy(toIt.template Get<int64_t>(), fromAt, elementBytes);
-    } else if (elementBytes == 4) {
-      std::memcpy(toIt.template Get<int32_t>(), fromAt, elementBytes);
-    } else if (elementBytes == 2) {
-      std::memcpy(toIt.template Get<int16_t>(), fromAt, elementBytes);
+    if constexpr (typeElementBytes != 1) {
+      std::memcpy(toIt.template Get<P>(), fromAt, typeElementBytes);
     } else {
-      std::memcpy(toIt.template Get<char>(), fromAt, elementBytes);
+      std::memcpy(toIt.template Get<P>(), fromAt, elementBytes);
     }
   }
 }
 
-RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
+// ShallowCopy helper for calling the correct specialised variant based on
+// scenario
+template <typename P, int RANK = -1>
+RT_API_ATTRS void ShallowCopyInner(const Descriptor &to, const Descriptor &from,
     bool toIsContiguous, bool fromIsContiguous) {
   if (toIsContiguous) {
     if (fromIsContiguous) {
       std::memcpy(to.OffsetElement(), from.OffsetElement(),
           to.Elements() * to.ElementBytes());
     } else {
-      if (to.rank() == 1 && from.rank() == 1) {
-        ShallowCopyDiscontiguousToContiguous<true>(to, from);
-      } else {
-        ShallowCopyDiscontiguousToContiguous<false>(to, from);
-      }
+      ShallowCopyDiscontiguousToContiguous<P, RANK>(to, from);
     }
   } else {
     if (fromIsContiguous) {
-      if (to.rank() == 1 && from.rank() == 1) {
-        ShallowCopyContiguousToDiscontiguous<true>(to, from);
-      } else {
-        ShallowCopyContiguousToDiscontiguous<false>(to, from);
-      }
+      ShallowCopyContiguousToDiscontiguous<P, RANK>(to, from);
     } else {
-      if (to.rank() == 1 && from.rank() == 1) {
-        ShallowCopyDiscontiguousToDiscontiguous<true>(to, from);
-      } else {
-        ShallowCopyDiscontiguousToDiscontiguous<false>(to, from);
-      }
+      ShallowCopyDiscontiguousToDiscontiguous<P, RANK>(to, from);
     }
   }
 }
 
+// ShallowCopy helper for specialising the variants based on array rank
+template <typename P>
+RT_API_ATTRS void ShallowCopyRank(const Descriptor &to, const Descriptor &from,
+    bool toIsContiguous, bool fromIsContiguous) {
+  if (to.rank() == 1 && from.rank() == 1) {
+    ShallowCopyInner<P, 1>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 2 && from.rank() == 2) {
+    ShallowCopyInner<P, 2>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 3 && from.rank() == 3) {
+    ShallowCopyInner<P, 3>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 4 && from.rank() == 4) {
+    ShallowCopyInner<P, 4>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 5 && from.rank() == 5) {
+    ShallowCopyInner<P, 5>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 6 && from.rank() == 6) {
+    ShallowCopyInner<P, 6>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 7 && from.rank() == 7) {
+    ShallowCopyInner<P, 7>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 8 && from.rank() == 8) {
+    ShallowCopyInner<P, 8>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 9 && from.rank() == 9) {
+    ShallowCopyInner<P, 9>(to, from, toIsContiguous, fromIsContiguous);
+  } else if (to.rank() == 10 && from.rank() == 10) {
+    ShallowCopyInner<P, 10>(to, from, toIsContiguous, fromIsContiguous);
+  } else {
+    ShallowCopyInner<P>(to, from, toIsContiguous, fromIsContiguous);
+  }
+}
+
+RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from,
+    bool toIsContiguous, bool fromIsContiguous) {
+  std::size_t elementBytes{to.ElementBytes()};
+  // Checking the type at runtime and making sure the pointer passed to memcpy
+  // has a type that matches the element type makes it possible for the compiler
+  // to optimise out the memcpy calls altogether and can substantially improve
+  // performance for some applications.
+  if (to.type().IsInteger()) {
+    if (elementBytes == sizeof(int64_t)) {
+      ShallowCopyRank<int64_t>(to, from, toIsContiguous, fromIsContiguous);
+    } else if (elementBytes == sizeof(int32_t)) {
+      ShallowCopyRank<int32_t>(to, from, toIsContiguous, fromIsContiguous);
+    } else if (elementBytes == sizeof(int16_t)) {
+      ShallowCopyRank<int16_t>(to, from, toIsContiguous, fromIsContiguous);
+#if defined USING_NATIVE_INT128_T
+    } else if (elementBytes == sizeof(__int128_t)) {
+      ShallowCopyRank<__int128_t>(to, from, toIsContiguous, fromIsContiguous);
+#endif
+    } else {
+      ShallowCopyRank<char>(to, from, toIsContiguous, fromIsContiguous);
+    }
+  } else if (to.type().IsReal()) {
+    if (elementBytes == sizeof(double)) {
+      ShallowCopyRank<double>(to, from, toIsContiguous, fromIsContiguous);
+    } else if (elementBytes == sizeof(float)) {
+      ShallowCopyRank<float>(to, from, toIsContiguous, fromIsContiguous);
+    } else {
+      ShallowCopyRank<char>(to, from, toIsContiguous, fromIsContiguous);
+    }
+  } else {
+    ShallowCopyRank<char>(to, from, toIsContiguous, fromIsContiguous);
+  }
+}
+
 RT_API_ATTRS void ShallowCopy(const Descriptor &to, const Descriptor &from) {
   ShallowCopy(to, from, to.IsContiguous(), from.IsContiguous());
 }

>From e3991f813f04cfafece887eeaea955362d8fbfc6 Mon Sep 17 00:00:00 2001
From: Kajetan Puchalski <kajetan.puchalski at arm.com>
Date: Wed, 21 May 2025 11:38:45 +0000
Subject: [PATCH 3/5] Review v2 - Recursive template and missing constexprs

---
 .../include/flang-rt/runtime/descriptor.h     | 10 ++--
 flang-rt/lib/runtime/tools.cpp                | 47 ++++++++++---------
 2 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index 622e1970607d8..544ca8c84b7e5 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -453,9 +453,9 @@ template <int RANK = -1> class DescriptorIterator {
 public:
   RT_API_ATTRS DescriptorIterator(const Descriptor &descriptor)
       : descriptor(descriptor) {
-    descriptor.GetLowerBounds(subscripts);
-    if constexpr (RANK == 1) {
-      elementOffset = descriptor.SubscriptByteOffset(0, subscripts[0]);
+    // We do not need the subscripts to iterate over a rank-1 array
+    if constexpr (RANK != 1) {
+      descriptor.GetLowerBounds(subscripts);
     }
   };
 
@@ -466,7 +466,7 @@ template <int RANK = -1> class DescriptorIterator {
       offset = elementOffset;
       // The compiler might be able to optimise this better if we know the rank
       // at compile time
-    } else if (RANK != -1) {
+    } else if constexpr (RANK != -1) {
       for (int j{0}; j < RANK; ++j) {
         offset += descriptor.SubscriptByteOffset(j, subscripts[j]);
       }
@@ -481,7 +481,7 @@ template <int RANK = -1> class DescriptorIterator {
   RT_API_ATTRS void Advance() {
     if constexpr (RANK == 1) {
       elementOffset += descriptor.GetDimension(0).ByteStride();
-    } else if (RANK != -1) {
+    } else if constexpr (RANK != -1) {
       for (int j{0}; j < RANK; ++j) {
         const Dimension &dim{descriptor.GetDimension(j)};
         if (subscripts[j]++ < dim.UpperBound()) {
diff --git a/flang-rt/lib/runtime/tools.cpp b/flang-rt/lib/runtime/tools.cpp
index 2387ce612fc5e..59f6905aaefd1 100644
--- a/flang-rt/lib/runtime/tools.cpp
+++ b/flang-rt/lib/runtime/tools.cpp
@@ -191,31 +191,36 @@ RT_API_ATTRS void ShallowCopyInner(const Descriptor &to, const Descriptor &from,
   }
 }
 
+// Most arrays are much closer to rank-1 than to maxRank.
+// Doing the recursion upwards instead of downwards puts the more common
+// cases earlier in the if-chain and has a tangible impact on performance.
+template <typename P, int RANK> struct ShallowCopyRankSpecialize {
+  static bool execute(const Descriptor &to, const Descriptor &from,
+      bool toIsContiguous, bool fromIsContiguous) {
+    if (to.rank() == RANK && from.rank() == RANK) {
+      ShallowCopyInner<P, RANK>(to, from, toIsContiguous, fromIsContiguous);
+      return true;
+    }
+    return ShallowCopyRankSpecialize<P, RANK + 1>::execute(
+        to, from, toIsContiguous, fromIsContiguous);
+  }
+};
+
+template <typename P> struct ShallowCopyRankSpecialize<P, maxRank + 1> {
+  static bool execute(const Descriptor &to, const Descriptor &from,
+      bool toIsContiguous, bool fromIsContiguous) {
+    return false;
+  }
+};
+
 // ShallowCopy helper for specialising the variants based on array rank
 template <typename P>
 RT_API_ATTRS void ShallowCopyRank(const Descriptor &to, const Descriptor &from,
     bool toIsContiguous, bool fromIsContiguous) {
-  if (to.rank() == 1 && from.rank() == 1) {
-    ShallowCopyInner<P, 1>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 2 && from.rank() == 2) {
-    ShallowCopyInner<P, 2>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 3 && from.rank() == 3) {
-    ShallowCopyInner<P, 3>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 4 && from.rank() == 4) {
-    ShallowCopyInner<P, 4>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 5 && from.rank() == 5) {
-    ShallowCopyInner<P, 5>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 6 && from.rank() == 6) {
-    ShallowCopyInner<P, 6>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 7 && from.rank() == 7) {
-    ShallowCopyInner<P, 7>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 8 && from.rank() == 8) {
-    ShallowCopyInner<P, 8>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 9 && from.rank() == 9) {
-    ShallowCopyInner<P, 9>(to, from, toIsContiguous, fromIsContiguous);
-  } else if (to.rank() == 10 && from.rank() == 10) {
-    ShallowCopyInner<P, 10>(to, from, toIsContiguous, fromIsContiguous);
-  } else {
+  // Try to call a specialised ShallowCopy variant from rank-1 up to maxRank
+  bool specialized = ShallowCopyRankSpecialize<P, 1>::execute(
+      to, from, toIsContiguous, fromIsContiguous);
+  if (!specialized) {
     ShallowCopyInner<P>(to, from, toIsContiguous, fromIsContiguous);
   }
 }

>From 61f5e134487dc55d188af91a6c872dda520e0751 Mon Sep 17 00:00:00 2001
From: Kajetan Puchalski <kajetan.puchalski at arm.com>
Date: Wed, 21 May 2025 12:56:11 +0000
Subject: [PATCH 4/5] Add unit test for CopyInAssign

---
 flang-rt/unittests/Runtime/Assign.cpp     | 55 +++++++++++++++++++++++
 flang-rt/unittests/Runtime/CMakeLists.txt |  1 +
 2 files changed, 56 insertions(+)
 create mode 100644 flang-rt/unittests/Runtime/Assign.cpp

diff --git a/flang-rt/unittests/Runtime/Assign.cpp b/flang-rt/unittests/Runtime/Assign.cpp
new file mode 100644
index 0000000000000..4001cc90ca0a1
--- /dev/null
+++ b/flang-rt/unittests/Runtime/Assign.cpp
@@ -0,0 +1,55 @@
+//===-- unittests/Runtime/Assign.cpp ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Runtime/assign.h"
+#include "tools.h"
+#include "gtest/gtest.h"
+#include <vector>
+
+using namespace Fortran::runtime;
+using Fortran::common::TypeCategory;
+
+TEST(Assign, RTNAME(CopyInAssign)) {
+  // contiguous -> contiguous copy in
+  auto intArray{MakeArray<TypeCategory::Integer, 4>(
+      std::vector<int>{2, 3}, std::vector<int>{1, 2, 3, 4, 5, 6}, sizeof(int))};
+  StaticDescriptor<2> staticIntResult;
+  Descriptor &intResult{staticIntResult.descriptor()};
+
+  RTNAME(CopyInAssign(intResult, *intArray));
+  ASSERT_TRUE(intResult.IsAllocated());
+  ASSERT_TRUE(intResult.IsContiguous());
+  ASSERT_EQ(intResult.type(), intArray->type());
+  ASSERT_EQ(intResult.ElementBytes(), sizeof(int));
+  EXPECT_EQ(intResult.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(intResult.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(intResult.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(intResult.GetDimension(1).Extent(), 3);
+  int expected[6] = {1, 2, 3, 4, 5, 6};
+  EXPECT_EQ(
+      std::memcmp(intResult.OffsetElement<int>(0), expected, 6 * sizeof(int)),
+      0);
+  intResult.Destroy();
+
+  // discontiguous -> contiguous rank-1 copy in
+  intArray = MakeArray<TypeCategory::Integer, 4>(std::vector<int>{8},
+      std::vector<int>{1, 2, 3, 4, 5, 6, 7, 8}, sizeof(int));
+  StaticDescriptor<1> staticIntResultStrided;
+  Descriptor &intResultStrided{staticIntResultStrided.descriptor()};
+  // View the source as 4 elements strided by 2 ints (picks 1, 3, 5, 7)
+  intArray->GetDimension(0).SetByteStride(sizeof(int) * 2);
+  intArray->GetDimension(0).SetExtent(4);
+  RTNAME(CopyInAssign(intResultStrided, *intArray));
+
+  int expectedStrided[4] = {1, 3, 5, 7};
+  EXPECT_EQ(std::memcmp(intResultStrided.OffsetElement<int>(0), expectedStrided,
+                4 * sizeof(int)),
+      0);
+
+  intResultStrided.Destroy();
+}
diff --git a/flang-rt/unittests/Runtime/CMakeLists.txt b/flang-rt/unittests/Runtime/CMakeLists.txt
index 61d0aba93b14b..49f55a442863b 100644
--- a/flang-rt/unittests/Runtime/CMakeLists.txt
+++ b/flang-rt/unittests/Runtime/CMakeLists.txt
@@ -10,6 +10,7 @@ add_flangrt_unittest(RuntimeTests
   AccessTest.cpp
   Allocatable.cpp
   ArrayConstructor.cpp
+  Assign.cpp
   BufferTest.cpp
   CharacterTest.cpp
   CommandTest.cpp

>From ea1f476dfe5b298695401a71a6f666d9fd623c39 Mon Sep 17 00:00:00 2001
From: Kajetan Puchalski <kajetan.puchalski at tuta.io>
Date: Thu, 22 May 2025 13:00:14 +0100
Subject: [PATCH 5/5] Braced initializers

Co-authored-by: Yusuke MINATO <minato.yusuke at fujitsu.com>
---
 flang-rt/include/flang-rt/runtime/descriptor.h | 4 ++--
 flang-rt/lib/runtime/tools.cpp                 | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index 544ca8c84b7e5..aa6ec1dbdebea 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -448,7 +448,7 @@ template <int RANK = -1> class DescriptorIterator {
 private:
   const Descriptor &descriptor;
   SubscriptValue subscripts[maxRank];
-  std::size_t elementOffset = 0;
+  std::size_t elementOffset{0};
 
 public:
   RT_API_ATTRS DescriptorIterator(const Descriptor &descriptor)
@@ -460,7 +460,7 @@ template <int RANK = -1> class DescriptorIterator {
   };
 
   template <typename A> RT_API_ATTRS A *Get() {
-    std::size_t offset = 0;
+    std::size_t offset{0};
     // The rank-1 case doesn't require looping at all
     if constexpr (RANK == 1) {
       offset = elementOffset;
diff --git a/flang-rt/lib/runtime/tools.cpp b/flang-rt/lib/runtime/tools.cpp
index 59f6905aaefd1..e13d0fe10a63a 100644
--- a/flang-rt/lib/runtime/tools.cpp
+++ b/flang-rt/lib/runtime/tools.cpp
@@ -218,8 +218,8 @@ template <typename P>
 RT_API_ATTRS void ShallowCopyRank(const Descriptor &to, const Descriptor &from,
     bool toIsContiguous, bool fromIsContiguous) {
   // Try to call a specialised ShallowCopy variant from rank-1 up to maxRank
-  bool specialized = ShallowCopyRankSpecialize<P, 1>::execute(
-      to, from, toIsContiguous, fromIsContiguous);
+  bool specialized{ShallowCopyRankSpecialize<P, 1>::execute(
+      to, from, toIsContiguous, fromIsContiguous)};
   if (!specialized) {
     ShallowCopyInner<P>(to, from, toIsContiguous, fromIsContiguous);
   }



More information about the llvm-commits mailing list