[llvm] [flang][runtime] Optimize Descriptor::FixedStride() (PR #151755)

Peter Klausler via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 1 12:29:49 PDT 2025


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/151755

>From 44e79cf9b7a89ea7ef8e7bf100a86efca349e2a7 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 30 Jul 2025 09:41:14 -0700
Subject: [PATCH 1/2] wip

---
 .../include/flang-rt/runtime/descriptor.h     | 29 ++++++++++++-------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index bc5a5b5f14697..ff0cb31f9946c 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -433,7 +433,9 @@ class Descriptor {
     bool stridesAreContiguous{true};
     for (int j{0}; j < leadingDimensions; ++j) {
       const Dimension &dim{GetDimension(j)};
-      stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1;
+      if (bytes != dim.ByteStride() && dim.Extent() != 1) {
+        stridesAreContiguous = false;
+      }
       bytes *= dim.Extent();
     }
     // One and zero element arrays are contiguous even if the descriptor
@@ -448,25 +450,32 @@ class Descriptor {
 
   // The result, if any, is a fixed stride value that can be used to
   // address all elements.  It generalizes contiguity by also allowing
-  // the case of an array with extent 1 on all but one dimension.
+  // the case of an array with extent 1 on all dimensions but one.
   RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
     auto rank{static_cast<std::size_t>(raw_.rank)};
     common::optional<SubscriptValue> stride;
+    auto bytes{static_cast<SubscriptValue>(ElementBytes())};
     for (std::size_t j{0}; j < rank; ++j) {
       const Dimension &dim{GetDimension(j)};
       auto extent{dim.Extent()};
       if (extent == 0) {
-        break; // empty array
+        return 0; // empty array
       } else if (extent == 1) { // ok
-      } else if (stride) {
-        // Extent > 1 on multiple dimensions
-        if (IsContiguous()) {
-          return ElementBytes();
+      } else {
+        if (stride) {
+          // Extent > 1 on multiple dimensions
+          if (bytes != dim.ByteStride()) {
+            while (++j < rank) {
+              if (dim.Extent() == 0) {
+                return 0; // empty array
+              }
+            }
+            return common::nullopt; // discontiguous
+          }
         } else {
-          return common::nullopt;
+          stride = dim.ByteStride();
         }
-      } else {
-        stride = dim.ByteStride();
+        bytes *= extent;
       }
     }
     return stride.value_or(0); // 0 for scalars and empty arrays

>From 71b302f500c70a88c58398fb9ecea8a9a4fcaa8f Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Fri, 1 Aug 2025 12:24:43 -0700
Subject: [PATCH 2/2] [flang][runtime] Optimize Descriptor::FixedStride()

Put the common cases on fast paths, and don't depend on IsContiguous()
in the general case path.  Add a unit test, too.
---
 .../include/flang-rt/runtime/descriptor.h     |  51 +++---
 flang-rt/unittests/Runtime/CMakeLists.txt     |   1 +
 flang-rt/unittests/Runtime/Descriptor.cpp     | 160 ++++++++++++++++++
 3 files changed, 191 insertions(+), 21 deletions(-)
 create mode 100644 flang-rt/unittests/Runtime/Descriptor.cpp

diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index ff0cb31f9946c..fb17387b7dd45 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -451,34 +451,43 @@ class Descriptor {
   // The result, if any, is a fixed stride value that can be used to
   // address all elements.  It generalizes contiguity by also allowing
   // the case of an array with extent 1 on all dimensions but one.
+  // Returns 0 for an empty array, ElementBytes() for a scalar, a byte
+  // stride if one is well-defined for the array, or nullopt otherwise.
   RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
-    auto rank{static_cast<std::size_t>(raw_.rank)};
-    common::optional<SubscriptValue> stride;
-    auto bytes{static_cast<SubscriptValue>(ElementBytes())};
-    for (std::size_t j{0}; j < rank; ++j) {
-      const Dimension &dim{GetDimension(j)};
-      auto extent{dim.Extent()};
-      if (extent == 0) {
-        return 0; // empty array
-      } else if (extent == 1) { // ok
-      } else {
-        if (stride) {
-          // Extent > 1 on multiple dimensions
-          if (bytes != dim.ByteStride()) {
-            while (++j < rank) {
-              if (dim.Extent() == 0) {
-                return 0; // empty array
+    int rank{raw_.rank};
+    auto elementBytes{static_cast<SubscriptValue>(ElementBytes())};
+    if (rank == 0) {
+      return elementBytes;
+    } else if (rank == 1) {
+      const Dimension &dim{GetDimension(0)};
+      return dim.Extent() == 0 ? 0 : dim.ByteStride();
+    } else {
+      common::optional<SubscriptValue> stride;
+      auto bytes{elementBytes};
+      for (int j{0}; j < rank; ++j) {
+        const Dimension &dim{GetDimension(j)};
+        auto extent{dim.Extent()};
+        if (extent == 0) {
+          return 0; // empty array
+        } else if (extent == 1) { // ok
+        } else {
+          if (stride) { // multiple extents > 1: must be contiguous
+            if (*stride != elementBytes || bytes != dim.ByteStride()) {
+              while (++j < rank) {
+                if (GetDimension(j).Extent() == 0) {
+                  return 0; // empty array
+                }
               }
+              return common::nullopt; // nonempty, discontiguous
             }
-            return common::nullopt; // discontiguous
+          } else {
+            stride = dim.ByteStride();
           }
-        } else {
-          stride = dim.ByteStride();
+          bytes *= extent;
         }
-        bytes *= extent;
       }
+      return stride.value_or(elementBytes /*for singleton*/);
     }
-    return stride.value_or(0); // 0 for scalars and empty arrays
   }
 
   // Establishes a pointer to a section or element.
diff --git a/flang-rt/unittests/Runtime/CMakeLists.txt b/flang-rt/unittests/Runtime/CMakeLists.txt
index cf1e15ddfa3e7..e51bc24415773 100644
--- a/flang-rt/unittests/Runtime/CMakeLists.txt
+++ b/flang-rt/unittests/Runtime/CMakeLists.txt
@@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests
   Complex.cpp
   CrashHandlerFixture.cpp
   Derived.cpp
+  Descriptor.cpp
   ExternalIOTest.cpp
   Format.cpp
   InputExtensions.cpp
diff --git a/flang-rt/unittests/Runtime/Descriptor.cpp b/flang-rt/unittests/Runtime/Descriptor.cpp
new file mode 100644
index 0000000000000..3a4a7670fc62e
--- /dev/null
+++ b/flang-rt/unittests/Runtime/Descriptor.cpp
@@ -0,0 +1,160 @@
+//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang-rt/runtime/descriptor.h"
+#include "tools.h"
+#include "gtest/gtest.h"
+
+using namespace Fortran::runtime;
+
+// Checks FixedStride() and IsContiguous(); -666 marks an unexpected nullopt.
+TEST(Descriptor, FixedStride) {
+  StaticDescriptor<4> staticDesc;
+  Descriptor &descriptor{staticDesc.descriptor()};
+  using Type = std::int32_t;
+  Type data[8][8][8];
+  constexpr int four{static_cast<int>(sizeof data[0][0][0])};
+  TypeCode integer{TypeCategory::Integer, four};
+  // Scalar
+  descriptor.Establish(integer, four, data, 0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Empty vector
+  SubscriptValue extent[3]{0, 0, 0};
+  descriptor.Establish(integer, four, data, 1, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Contiguous vector (0:7:1)
+  extent[0] = 8;
+  descriptor.Establish(integer, four, data, 1, extent);
+  ASSERT_EQ(descriptor.rank(), 1);
+  ASSERT_EQ(descriptor.Elements(), 8);
+  ASSERT_EQ(descriptor.ElementBytes(), four);
+  ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0);
+  ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four);
+  ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Contiguous reverse vector (7:0:-1)
+  descriptor.GetDimension(0).SetByteStride(-four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+  // Discontiguous vector (0:6:2)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+  // Empty matrix
+  extent[0] = 0;
+  descriptor.Establish(integer, four, data, 2, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Contiguous matrix (0:7, 0:7)
+  extent[0] = extent[1] = 8;
+  descriptor.Establish(integer, four, data, 2, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Contiguous row (0:7, 0)
+  descriptor.GetDimension(1).SetExtent(1);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Contiguous column (0, 0:7)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(1).SetExtent(8);
+  descriptor.GetDimension(1).SetByteStride(8 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four);
+  // Contiguous reverse row (7:0:-1, 0)
+  descriptor.GetDimension(0).SetExtent(8);
+  descriptor.GetDimension(0).SetByteStride(-four);
+  descriptor.GetDimension(1).SetExtent(1);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+  // Contiguous reverse column (0, 7:0:-1)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(8);
+  descriptor.GetDimension(1).SetByteStride(8 * -four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four);
+  // Discontiguous row (0:6:2, 0)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  descriptor.GetDimension(1).SetExtent(1);
+  descriptor.GetDimension(1).SetByteStride(four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+  // Discontiguous column (0, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four);
+  // Discontiguous reverse row (7:1:-2, 0)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(-2 * four);
+  descriptor.GetDimension(1).SetExtent(1);
+  descriptor.GetDimension(1).SetByteStride(four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four);
+  // Discontiguous reverse column (0, 7:1:-2)
+  descriptor.GetDimension(0).SetExtent(1);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * -2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four);
+  // Discontiguous rows (0:6:2, 0:1)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  descriptor.GetDimension(1).SetExtent(2);
+  descriptor.GetDimension(1).SetByteStride(8 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_FALSE(descriptor.FixedStride().has_value());
+  // Discontiguous columns (0:1, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(2);
+  descriptor.GetDimension(0).SetByteStride(four);
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_FALSE(descriptor.FixedStride().has_value());
+  // Empty 3-D array
+  extent[0] = extent[1] = extent[2] = 0;
+  descriptor.Establish(integer, four, data, 3, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Contiguous 3-D array (0:7, 0:7, 0:7)
+  extent[0] = extent[1] = extent[2] = 8;
+  descriptor.Establish(integer, four, data, 3, extent);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+  // Discontiguous 3-D array (0:7, 0:6:2, 0:6:2)
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  descriptor.GetDimension(2).SetExtent(4);
+  descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four);
+  EXPECT_FALSE(descriptor.IsContiguous());
+  EXPECT_FALSE(descriptor.FixedStride().has_value());
+  // Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2)
+  descriptor.GetDimension(0).SetExtent(4);
+  descriptor.GetDimension(0).SetByteStride(2 * four);
+  descriptor.GetDimension(1).SetExtent(0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+  // Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1)
+  descriptor.GetDimension(1).SetExtent(4);
+  descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+  descriptor.GetDimension(2).SetExtent(0);
+  EXPECT_TRUE(descriptor.IsContiguous());
+  EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+}



More information about the llvm-commits mailing list