[llvm] [flang][runtime] Optimize Descriptor::FixedStride() (PR #151755)

Peter Klausler via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 1 16:55:41 PDT 2025


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/151755

>From 65480f5b70a70017b286d29a4467f3269e7f9882 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 30 Jul 2025 09:41:14 -0700
Subject: [PATCH] [flang][runtime] Optimize Descriptor::FixedStride()

Put the common cases on fast paths, and don't depend on IsContiguous()
in the general case path.  Add a unit test, too.
---
 .../include/flang-rt/runtime/descriptor.h     |  57 +++++--
 flang-rt/lib/runtime/derived.cpp              |  23 ++-
 flang-rt/lib/runtime/descriptor.cpp           |   9 +
 flang-rt/unittests/Runtime/CMakeLists.txt     |   1 +
 flang-rt/unittests/Runtime/Descriptor.cpp     | 160 ++++++++++++++++++
 5 files changed, 230 insertions(+), 20 deletions(-)
 create mode 100644 flang-rt/unittests/Runtime/Descriptor.cpp

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index bc5a5b5f14697..abd668e27f18f 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -181,6 +181,9 @@ class Descriptor {
       const SubscriptValue *extent = nullptr,
       ISO::CFI_attribute_t attribute = CFI_attribute_other);
 
+  RT_API_ATTRS void UncheckedScalarEstablish(
+      const typeInfo::DerivedType &, void *);
+
   // To create a descriptor for a derived type the caller
   // must provide non-null dt argument.
   // The addendum argument is only used for testing purposes,
@@ -433,7 +436,9 @@ class Descriptor {
     bool stridesAreContiguous{true};
     for (int j{0}; j < leadingDimensions; ++j) {
       const Dimension &dim{GetDimension(j)};
-      stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1;
+      if (bytes != dim.ByteStride() && dim.Extent() != 1) {
+        stridesAreContiguous = false;
+      }
       bytes *= dim.Extent();
     }
     // One and zero element arrays are contiguous even if the descriptor
@@ -448,28 +453,44 @@ class Descriptor {
 
   // The result, if any, is a fixed stride value that can be used to
   // address all elements.  It generalizes contiguity by also allowing
-  // the case of an array with extent 1 on all but one dimension.
+  // the case of an array with extent 1 on all dimensions but one.
+  // Returns 0 for an empty array, else a valid byte stride, or nullopt.
+  // Conservative: nullopt for any noncontiguous array with 2+ extents > 1.
   RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
-    auto rank{static_cast<std::size_t>(raw_.rank)};
-    common::optional<SubscriptValue> stride;
-    for (std::size_t j{0}; j < rank; ++j) {
-      const Dimension &dim{GetDimension(j)};
-      auto extent{dim.Extent()};
-      if (extent == 0) {
-        break; // empty array
-      } else if (extent == 1) { // ok
-      } else if (stride) {
-        // Extent > 1 on multiple dimensions
-        if (IsContiguous()) {
-          return ElementBytes();
+    int rank{raw_.rank};
+    auto elementBytes{static_cast<SubscriptValue>(ElementBytes())};
+    if (rank == 0) {
+      return elementBytes;
+    } else if (rank == 1) {
+      const Dimension &dim{GetDimension(0)};
+      return dim.Extent() == 0 ? 0 : dim.ByteStride();
+    } else {
+      common::optional<SubscriptValue> stride;
+      auto bytes{elementBytes}; // next dim's stride if contiguous
+      for (int j{0}; j < rank; ++j) {
+        const Dimension &dim{GetDimension(j)};
+        auto extent{dim.Extent()};
+        if (extent == 0) {
+          return 0; // empty array
+        } else if (extent == 1) { // ok
         } else {
-          return common::nullopt;
+          if (stride) { // extent > 1 on several dims: must be contiguous
+            if (*stride != elementBytes || bytes != dim.ByteStride()) {
+              while (++j < rank) {
+                if (GetDimension(j).Extent() == 0) {
+                  return 0; // empty array
+                }
+              }
+              return common::nullopt; // nonempty, discontiguous
+            }
+          } else {
+            stride = dim.ByteStride();
+          }
+          bytes *= extent;
         }
-      } else {
-        stride = dim.ByteStride();
       }
+      return stride.value_or(elementBytes /*for singleton*/);
     }
-    return stride.value_or(0); // 0 for scalars and empty arrays
   }
 
   // Establishes a pointer to a section or element.
diff --git a/flang-rt/lib/runtime/derived.cpp b/flang-rt/lib/runtime/derived.cpp
index 4ed0baaa3d108..c3ef99df30769 100644
--- a/flang-rt/lib/runtime/derived.cpp
+++ b/flang-rt/lib/runtime/derived.cpp
@@ -360,6 +360,8 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) {
     } else if (component_->genre() == typeInfo::Component::Genre::Data &&
         component_->derivedType() &&
         !component_->derivedType()->noFinalizationNeeded()) {
+      // todo: calculate and use fixedStride_ here as in DestroyTicket to
+      // avoid subscripts and repeated descriptor establishment.
       SubscriptValue extents[maxRank];
       GetComponentExtents(extents, *component_, instance_);
       Descriptor &compDesc{componentDescriptor_.descriptor()};
@@ -452,6 +454,24 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
     } else if (component_->genre() == typeInfo::Component::Genre::Data) {
       if (!componentDerived || componentDerived->noDestructionNeeded()) {
         SkipToNextComponent();
+      } else if (fixedStride_) {
+        // faster path, no need for subscripts, can reuse descriptor
+        char *p{instance_.OffsetElement<char>(
+            elementAt_ * *fixedStride_ + component_->offset())};
+        Descriptor &compDesc{componentDescriptor_.descriptor()};
+        const typeInfo::DerivedType &compType{*componentDerived};
+        compDesc.UncheckedScalarEstablish(compType, p);
+        for (std::size_t j{elementAt_}; j < elements_;
+            ++j, p += *fixedStride_) {
+          compDesc.set_base_addr(p);
+          ++elementAt_;
+          if (int status{workQueue.BeginDestroy(
+                  compDesc, compType, /*finalize=*/false)};
+              status != StatOk) {
+            return status;
+          }
+        }
+        SkipToNextComponent();
       } else {
         SubscriptValue extents[maxRank];
         GetComponentExtents(extents, *component_, instance_);
@@ -461,8 +481,7 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
             instance_.ElementComponent<char>(subscripts_, component_->offset()),
             component_->rank(), extents);
         Advance();
-        if (int status{workQueue.BeginDestroy(
-                compDesc, *componentDerived, /*finalize=*/false)};
+        if (int status{workQueue.BeginDestroy(compDesc, compType, /*finalize=*/false)};
             status != StatOk) {
           return status;
         }
diff --git a/flang-rt/lib/runtime/descriptor.cpp b/flang-rt/lib/runtime/descriptor.cpp
index 021440cbdd0f6..fde4baa6a317c 100644
--- a/flang-rt/lib/runtime/descriptor.cpp
+++ b/flang-rt/lib/runtime/descriptor.cpp
@@ -100,6 +100,15 @@ RT_API_ATTRS void Descriptor::Establish(const typeInfo::DerivedType &dt,
   new (Addendum()) DescriptorAddendum{&dt};
 }
 
+RT_API_ATTRS void Descriptor::UncheckedScalarEstablish(
+    const typeInfo::DerivedType &derived, void *base) {
+  // Rank-0 struct descriptor; EstablishDescriptor's result is not checked.
+  ISO::EstablishDescriptor(&raw_, base, CFI_attribute_other, CFI_type_struct,
+      static_cast<std::size_t>(derived.sizeInBytes()), 0, nullptr);
+  SetHasAddendum(); // then construct the addendum naming the derived type
+  new (Addendum()) DescriptorAddendum{&derived};
+}
+
 RT_API_ATTRS OwningPtr<Descriptor> Descriptor::Create(TypeCode t,
     std::size_t elementBytes, void *p, int rank, const SubscriptValue *extent,
     ISO::CFI_attribute_t attribute, bool addendum,
diff --git a/flang-rt/unittests/Runtime/CMakeLists.txt b/flang-rt/unittests/Runtime/CMakeLists.txt
index cf1e15ddfa3e7..e51bc24415773 100644
--- a/flang-rt/unittests/Runtime/CMakeLists.txt
+++ b/flang-rt/unittests/Runtime/CMakeLists.txt
@@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests
   Complex.cpp
   CrashHandlerFixture.cpp
   Derived.cpp
+  Descriptor.cpp
   ExternalIOTest.cpp
   Format.cpp
   InputExtensions.cpp
diff --git a/flang-rt/unittests/Runtime/Descriptor.cpp b/flang-rt/unittests/Runtime/Descriptor.cpp
new file mode 100644
index 0000000000000..3a4a7670fc62e
--- /dev/null
+++ b/flang-rt/unittests/Runtime/Descriptor.cpp
@@ -0,0 +1,160 @@
+//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang-rt/runtime/descriptor.h"
+#include "tools.h"
+#include "gtest/gtest.h"
+
+using namespace Fortran::runtime;
+
+TEST(Descriptor, FixedStride) {
+  StaticDescriptor<4> staticDesc[2];
+  Descriptor &descriptor{staticDesc[0].descriptor()};
+  using Type = std::int32_t;
+  Type data[8][8][8];
+  constexpr int four{static_cast<int>(sizeof data[0][0][0])};
+  TypeCode integer{TypeCategory::Integer, four};
+  // Scalar
+  descriptor.Establish(integer, four, data, 0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Empty vector
+  SubscriptValue extent[3]{0, 0, 0};
+  descriptor.Establish(integer, four, data, 1, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Contiguous vector (0:7:1)
+  extent[0] = 8;
+  descriptor.Establish(integer, four, data, 1, extent);
+  ASSERT_EQ(descriptor.rank(), 1);
+  ASSERT_EQ(descriptor.Elements(), 8);
+  ASSERT_EQ(descriptor.ElementBytes(), four);
+  ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0);
+  ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four);
+  ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Contiguous reverse vector (7:0:-1)
+  descriptor.GetDimension(0).SetByteStride(-four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+  // Discontiguous vector (0:6:2)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+  // Empty matrix
+  extent[0] = 0;
+  descriptor.Establish(integer, four, data, 2, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Contiguous matrix (0:7, 0:7)
+  extent[0] = extent[1] = 8;
+  descriptor.Establish(integer, four, data, 2, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Contiguous row (0:7, 0)
+  descriptor.GetDimension(1).SetExtent(1);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Contiguous column (0, 0:7)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(1).SetExtent(7);
+  descriptor.GetDimension(1).SetByteStride(8 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four);
+  // Contiguous reverse row (7:0:-1, 0)
+  descriptor.GetDimension(0).SetExtent(8);
+  descriptor.GetDimension(0).SetByteStride(-four);
+  descriptor.GetDimension(1).SetExtent(1);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+  // Contiguous reverse column (0, 7:0:-1)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(7);
+  descriptor.GetDimension(1).SetByteStride(8 * -four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four);
+  // Discontiguous row (0:6:2, 0)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  descriptor.GetDimension(1).SetExtent(1);
+  descriptor.GetDimension(1).SetByteStride(four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+  // Discontiguous column (0, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four);
+  // Discontiguous reverse row (7:1:-2, 0)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(-2 * four);
+  descriptor.GetDimension(1).SetExtent(1);
+  descriptor.GetDimension(1).SetByteStride(four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four);
+  // Discontiguous reverse column (0, 7:1:-2)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * -2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four);
+  // Discontiguous rows (0:6:2, 0:1): conservatively, no fixed stride
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  descriptor.GetDimension(1).SetExtent(2);
+  descriptor.GetDimension(1).SetByteStride(8 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_FALSE(descriptor.FixedStride().has_value());
+  // Discontiguous columns (0:1, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(2);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_FALSE(descriptor.FixedStride().has_value());
+  // Empty 3-D array
+  extent[0] = extent[1] = extent[2] = 0;
+  descriptor.Establish(integer, four, data, 3, extent);
+  ASSERT_EQ(descriptor.rank(), 3);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Contiguous 3-D array (0:7, 0:7, 0:7)
+  extent[0] = extent[1] = extent[2] = 8;
+  descriptor.Establish(integer, four, data, 3, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Discontiguous 3-D array (0:7, 0:6:2, 0:6:2)
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  descriptor.GetDimension(2).SetExtent(4);
+  descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_FALSE(descriptor.FixedStride().has_value());
+  // Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  descriptor.GetDimension(1).SetExtent(0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1)
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  descriptor.GetDimension(2).SetExtent(0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+}



More information about the llvm-commits mailing list