[llvm] [flang][runtime] Optimize Descriptor::FixedStride() (PR #151755)
Peter Klausler via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 1 12:29:49 PDT 2025
https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/151755
>From 44e79cf9b7a89ea7ef8e7bf100a86efca349e2a7 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 30 Jul 2025 09:41:14 -0700
Subject: [PATCH 1/2] wip
---
.../include/flang-rt/runtime/descriptor.h | 29 ++++++++++++-------
1 file changed, 19 insertions(+), 10 deletions(-)
diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index bc5a5b5f14697..ff0cb31f9946c 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -433,7 +433,9 @@ class Descriptor {
bool stridesAreContiguous{true};
for (int j{0}; j < leadingDimensions; ++j) {
const Dimension &dim{GetDimension(j)};
- stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1;
+ if (bytes != dim.ByteStride() && dim.Extent() != 1) {
+ stridesAreContiguous = false;
+ }
bytes *= dim.Extent();
}
// One and zero element arrays are contiguous even if the descriptor
@@ -448,25 +450,32 @@ class Descriptor {
// The result, if any, is a fixed stride value that can be used to
// address all elements. It generalizes contiguity by also allowing
- // the case of an array with extent 1 on all but one dimension.
+ // the case of an array with extent 1 on all dimensions but one.
RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
auto rank{static_cast<std::size_t>(raw_.rank)};
common::optional<SubscriptValue> stride;
+ auto bytes{static_cast<SubscriptValue>(ElementBytes())};
for (std::size_t j{0}; j < rank; ++j) {
const Dimension &dim{GetDimension(j)};
auto extent{dim.Extent()};
if (extent == 0) {
- break; // empty array
+ return 0; // empty array
} else if (extent == 1) { // ok
- } else if (stride) {
- // Extent > 1 on multiple dimensions
- if (IsContiguous()) {
- return ElementBytes();
+ } else {
+ if (stride) {
+ // Extent > 1 on multiple dimensions
+ if (bytes != dim.ByteStride()) {
+ while (++j < rank) {
+ if (dim.Extent() == 0) {
+ return 0; // empty array
+ }
+ }
+ return common::nullopt; // discontiguous
+ }
} else {
- return common::nullopt;
+ stride = dim.ByteStride();
}
- } else {
- stride = dim.ByteStride();
+ bytes *= extent;
}
}
return stride.value_or(0); // 0 for scalars and empty arrays
>From 71b302f500c70a88c58398fb9ecea8a9a4fcaa8f Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Fri, 1 Aug 2025 12:24:43 -0700
Subject: [PATCH 2/2] [flang][runtime] Optimize Descriptor::FixedStride()
Put the common cases on fast paths, and don't depend on IsContiguous()
in the general case path. Add a unit test, too.
---
.../include/flang-rt/runtime/descriptor.h | 51 +++---
flang-rt/unittests/Runtime/CMakeLists.txt | 1 +
flang-rt/unittests/Runtime/Descriptor.cpp | 160 ++++++++++++++++++
3 files changed, 191 insertions(+), 21 deletions(-)
create mode 100644 flang-rt/unittests/Runtime/Descriptor.cpp
diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index ff0cb31f9946c..fb17387b7dd45 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -451,34 +451,43 @@ class Descriptor {
// The result, if any, is a fixed stride value that can be used to
// address all elements. It generalizes contiguity by also allowing
// the case of an array with extent 1 on all dimensions but one.
+ // Returns 0 for an empty array, a byte stride if one is well-defined
+ // for the array, or nullopt otherwise.
RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
- auto rank{static_cast<std::size_t>(raw_.rank)};
- common::optional<SubscriptValue> stride;
- auto bytes{static_cast<SubscriptValue>(ElementBytes())};
- for (std::size_t j{0}; j < rank; ++j) {
- const Dimension &dim{GetDimension(j)};
- auto extent{dim.Extent()};
- if (extent == 0) {
- return 0; // empty array
- } else if (extent == 1) { // ok
- } else {
- if (stride) {
- // Extent > 1 on multiple dimensions
- if (bytes != dim.ByteStride()) {
- while (++j < rank) {
- if (dim.Extent() == 0) {
- return 0; // empty array
+ int rank{raw_.rank};
+ auto elementBytes{static_cast<SubscriptValue>(ElementBytes())};
+ if (rank == 0) {
+ return elementBytes;
+ } else if (rank == 1) {
+ const Dimension &dim{GetDimension(0)};
+ return dim.Extent() == 0 ? 0 : dim.ByteStride();
+ } else {
+ common::optional<SubscriptValue> stride;
+ auto bytes{elementBytes};
+ for (int j{0}; j < rank; ++j) {
+ const Dimension &dim{GetDimension(j)};
+ auto extent{dim.Extent()};
+ if (extent == 0) {
+ return 0; // empty array
+ } else if (extent == 1) { // ok
+ } else {
+ if (stride) { // Extent > 1 on multiple dimensions
+ if (bytes != dim.ByteStride()) { // discontiguity
+ while (++j < rank) {
+ if (GetDimension(j).Extent() == 0) {
+ return 0; // empty array
+ }
}
+ return common::nullopt; // nonempty, discontiguous
}
- return common::nullopt; // discontiguous
+ } else {
+ stride = dim.ByteStride();
}
- } else {
- stride = dim.ByteStride();
+ bytes *= extent;
}
- bytes *= extent;
}
+ return stride.value_or(elementBytes /*for singleton*/);
}
- return stride.value_or(0); // 0 for scalars and empty arrays
}
// Establishes a pointer to a section or element.
diff --git a/flang-rt/unittests/Runtime/CMakeLists.txt b/flang-rt/unittests/Runtime/CMakeLists.txt
index cf1e15ddfa3e7..e51bc24415773 100644
--- a/flang-rt/unittests/Runtime/CMakeLists.txt
+++ b/flang-rt/unittests/Runtime/CMakeLists.txt
@@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests
Complex.cpp
CrashHandlerFixture.cpp
Derived.cpp
+ Descriptor.cpp
ExternalIOTest.cpp
Format.cpp
InputExtensions.cpp
diff --git a/flang-rt/unittests/Runtime/Descriptor.cpp b/flang-rt/unittests/Runtime/Descriptor.cpp
new file mode 100644
index 0000000000000..3a4a7670fc62e
--- /dev/null
+++ b/flang-rt/unittests/Runtime/Descriptor.cpp
@@ -0,0 +1,160 @@
+//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang-rt/runtime/descriptor.h"
+#include "tools.h"
+#include "gtest/gtest.h"
+
+using namespace Fortran::runtime;
+
+TEST(Descriptor, FixedStride) {
+ StaticDescriptor<4> staticDesc[2];
+ Descriptor &descriptor{staticDesc[0].descriptor()};
+ using Type = std::int32_t;
+ Type data[8][8][8];
+ constexpr int four{static_cast<int>(sizeof data[0][0][0])};
+ TypeCode integer{TypeCategory::Integer, four};
+ // Scalar
+ descriptor.Establish(integer, four, data, 0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Empty vector
+ SubscriptValue extent[3]{0, 0, 0};
+ descriptor.Establish(integer, four, data, 1, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Contiguous vector (0:7:1)
+ extent[0] = 8;
+ descriptor.Establish(integer, four, data, 1, extent);
+ ASSERT_EQ(descriptor.rank(), 1);
+ ASSERT_EQ(descriptor.Elements(), 8);
+ ASSERT_EQ(descriptor.ElementBytes(), four);
+ ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0);
+ ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four);
+ ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Contiguous reverse vector (7:0:-1)
+ descriptor.GetDimension(0).SetByteStride(-four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+ // Discontiguous vector (0:6:2)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+ // Empty matrix
+ extent[0] = 0;
+ descriptor.Establish(integer, four, data, 2, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Contiguous matrix (0:7, 0:7)
+ extent[0] = extent[1] = 8;
+ descriptor.Establish(integer, four, data, 2, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Contiguous row (0:7, 0)
+ descriptor.GetDimension(1).SetExtent(1);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Contiguous column (0, 0:6)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(1).SetExtent(7);
+ descriptor.GetDimension(1).SetByteStride(8 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four);
+ // Contiguous reverse row (7:0:-1, 0)
+ descriptor.GetDimension(0).SetExtent(8);
+ descriptor.GetDimension(0).SetByteStride(-four);
+ descriptor.GetDimension(1).SetExtent(1);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+ // Contiguous reverse column (0, 6:0:-1)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(7);
+ descriptor.GetDimension(1).SetByteStride(8 * -four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four);
+ // Discontiguous row (0:6:2, 0)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ descriptor.GetDimension(1).SetExtent(1);
+ descriptor.GetDimension(1).SetByteStride(four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+ // Discontiguous column (0, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four);
+ // Discontiguous reverse row (7:1:-2, 0)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(-2 * four);
+ descriptor.GetDimension(1).SetExtent(1);
+ descriptor.GetDimension(1).SetByteStride(four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four);
+ // Discontiguous reverse column (0, 7:1:-2)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * -2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four);
+ // Discontiguous rows (0:6:2, 0:1)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ descriptor.GetDimension(1).SetExtent(2);
+ descriptor.GetDimension(1).SetByteStride(8 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_FALSE(descriptor.FixedStride().has_value());
+ // Discontiguous columns (0:1, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(2);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_FALSE(descriptor.FixedStride().has_value());
+ // Empty 3-D array
+ extent[0] = extent[1] = extent[2] = 0;
+ // Rank 3 with every extent zero, so the array has no elements.
+ descriptor.Establish(integer, four, data, 3, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Contiguous 3-D array (0:7, 0:7, 0:7)
+ extent[0] = extent[1] = extent[2] = 8;
+ descriptor.Establish(integer, four, data, 3, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Discontiguous 3-D array (0:7, 0:6:2, 0:6:2)
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ descriptor.GetDimension(2).SetExtent(4);
+ descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_FALSE(descriptor.FixedStride().has_value());
+ // Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ descriptor.GetDimension(1).SetExtent(0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1)
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ descriptor.GetDimension(2).SetExtent(0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+}
More information about the llvm-commits
mailing list