[llvm] [flang][runtime] Optimize Descriptor::FixedStride() (PR #151755)
Peter Klausler via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 1 16:55:41 PDT 2025
https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/151755
>From 65480f5b70a70017b286d29a4467f3269e7f9882 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 30 Jul 2025 09:41:14 -0700
Subject: [PATCH] [flang][runtime] Optimize Descriptor::FixedStride()
Put the common cases on fast paths, and don't depend on IsContiguous()
in the general case path. Add a unit test, too.
---
.../include/flang-rt/runtime/descriptor.h | 57 +++++--
flang-rt/lib/runtime/derived.cpp | 23 ++-
flang-rt/lib/runtime/descriptor.cpp | 9 +
flang-rt/unittests/Runtime/CMakeLists.txt | 1 +
flang-rt/unittests/Runtime/Descriptor.cpp | 160 ++++++++++++++++++
5 files changed, 230 insertions(+), 20 deletions(-)
create mode 100644 flang-rt/unittests/Runtime/Descriptor.cpp
diff --git a/flang-rt/include/flang-rt/runtime/descriptor.h b/flang-rt/include/flang-rt/runtime/descriptor.h
index bc5a5b5f14697..abd668e27f18f 100644
--- a/flang-rt/include/flang-rt/runtime/descriptor.h
+++ b/flang-rt/include/flang-rt/runtime/descriptor.h
@@ -181,6 +181,9 @@ class Descriptor {
const SubscriptValue *extent = nullptr,
ISO::CFI_attribute_t attribute = CFI_attribute_other);
+ RT_API_ATTRS void UncheckedScalarEstablish(
+ const typeInfo::DerivedType &, void *);
+
// To create a descriptor for a derived type the caller
// must provide non-null dt argument.
// The addendum argument is only used for testing purposes,
@@ -433,7 +436,9 @@ class Descriptor {
bool stridesAreContiguous{true};
for (int j{0}; j < leadingDimensions; ++j) {
const Dimension &dim{GetDimension(j)};
- stridesAreContiguous &= bytes == dim.ByteStride() || dim.Extent() == 1;
+ if (bytes != dim.ByteStride() && dim.Extent() != 1) {
+ stridesAreContiguous = false;
+ }
bytes *= dim.Extent();
}
// One and zero element arrays are contiguous even if the descriptor
@@ -448,28 +453,44 @@ class Descriptor {
// The result, if any, is a fixed stride value that can be used to
// address all elements. It generalizes contiguity by also allowing
- // the case of an array with extent 1 on all but one dimension.
+ // the case of an array with extent 1 on all dimensions but one.
+ // Returns 0 for an empty array, the byte stride when one addresses every
+ // element (several extents > 1 require contiguity), or nullopt otherwise.
RT_API_ATTRS common::optional<SubscriptValue> FixedStride() const {
- auto rank{static_cast<std::size_t>(raw_.rank)};
- common::optional<SubscriptValue> stride;
- for (std::size_t j{0}; j < rank; ++j) {
- const Dimension &dim{GetDimension(j)};
- auto extent{dim.Extent()};
- if (extent == 0) {
- break; // empty array
- } else if (extent == 1) { // ok
- } else if (stride) {
- // Extent > 1 on multiple dimensions
- if (IsContiguous()) {
- return ElementBytes();
+ int rank{raw_.rank};
+ auto elementBytes{static_cast<SubscriptValue>(ElementBytes())};
+ if (rank == 0) { // fast path: scalar
+ return elementBytes;
+ } else if (rank == 1) { // fast path: a vector's own stride always works
+ const Dimension &dim{GetDimension(0)};
+ return dim.Extent() == 0 ? 0 : dim.ByteStride();
+ } else {
+ common::optional<SubscriptValue> stride;
+ auto bytes{elementBytes}; // stride a contiguous dimension j would have
+ for (int j{0}; j < rank; ++j) {
+ const Dimension &dim{GetDimension(j)};
+ auto extent{dim.Extent()};
+ if (extent == 0) {
+ return 0; // empty array
+ } else if (extent == 1) { // ok
} else {
- return common::nullopt;
+ if (stride) { // Extent > 1 on multiple dimensions: must be contiguous
+ if (*stride != elementBytes || bytes != dim.ByteStride()) {
+ while (++j < rank) {
+ if (GetDimension(j).Extent() == 0) {
+ return 0; // empty array
+ }
+ }
+ return common::nullopt; // nonempty, discontiguous
+ }
+ } else {
+ stride = dim.ByteStride(); // first extent > 1; vetted if another follows
+ }
+ bytes *= extent;
}
- } else {
- stride = dim.ByteStride();
}
+ return stride.value_or(elementBytes /*for singleton*/);
}
- return stride.value_or(0); // 0 for scalars and empty arrays
}
// Establishes a pointer to a section or element.
diff --git a/flang-rt/lib/runtime/derived.cpp b/flang-rt/lib/runtime/derived.cpp
index 4ed0baaa3d108..c3ef99df30769 100644
--- a/flang-rt/lib/runtime/derived.cpp
+++ b/flang-rt/lib/runtime/derived.cpp
@@ -360,6 +360,8 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) {
} else if (component_->genre() == typeInfo::Component::Genre::Data &&
component_->derivedType() &&
!component_->derivedType()->noFinalizationNeeded()) {
+ // TODO: calculate and use fixedStride_ here as in DestroyTicket to
+ // avoid subscripts and repeated descriptor establishment.
SubscriptValue extents[maxRank];
GetComponentExtents(extents, *component_, instance_);
Descriptor &compDesc{componentDescriptor_.descriptor()};
@@ -452,6 +454,24 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
} else if (component_->genre() == typeInfo::Component::Genre::Data) {
if (!componentDerived || componentDerived->noDestructionNeeded()) {
SkipToNextComponent();
+ } else if (fixedStride_) {
+ // faster path, no need for subscripts, can reuse descriptor
+ char *p{instance_.OffsetElement<char>(
+ elementAt_ * *fixedStride_ + component_->offset())};
+ Descriptor &compDesc{componentDescriptor_.descriptor()};
+ const typeInfo::DerivedType &compType{*componentDerived};
+ compDesc.UncheckedScalarEstablish(compType, p);
+ for (std::size_t j{elementAt_}; j < elements_;
+ ++j, p += *fixedStride_) {
+ compDesc.set_base_addr(p);
+ ++elementAt_;
+ if (int status{workQueue.BeginDestroy(
+ compDesc, compType, /*finalize=*/false)};
+ status != StatOk) {
+ return status;
+ }
+ }
+ SkipToNextComponent();
} else {
SubscriptValue extents[maxRank];
GetComponentExtents(extents, *component_, instance_);
@@ -461,8 +481,7 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
instance_.ElementComponent<char>(subscripts_, component_->offset()),
component_->rank(), extents);
Advance();
- if (int status{workQueue.BeginDestroy(
- compDesc, *componentDerived, /*finalize=*/false)};
+ if (int status{workQueue.BeginDestroy(compDesc, compType, /*finalize=*/false)};
status != StatOk) {
return status;
}
diff --git a/flang-rt/lib/runtime/descriptor.cpp b/flang-rt/lib/runtime/descriptor.cpp
index 021440cbdd0f6..fde4baa6a317c 100644
--- a/flang-rt/lib/runtime/descriptor.cpp
+++ b/flang-rt/lib/runtime/descriptor.cpp
@@ -100,6 +100,15 @@ RT_API_ATTRS void Descriptor::Establish(const typeInfo::DerivedType &dt,
new (Addendum()) DescriptorAddendum{&dt};
}
+RT_API_ATTRS void Descriptor::UncheckedScalarEstablish(
+ const typeInfo::DerivedType &dt, void *p) { // dt: element type; p: base address
+ auto elementBytes{static_cast<std::size_t>(dt.sizeInBytes())}; // size taken from dt
+ ISO::EstablishDescriptor( // trailing 0, nullptr: presumably rank 0, no extents
+ &raw_, p, CFI_attribute_other, CFI_type_struct, elementBytes, 0, nullptr);
+ SetHasAddendum();
+ new (Addendum()) DescriptorAddendum{&dt}; // construct the addendum in place from dt
+}
+
RT_API_ATTRS OwningPtr<Descriptor> Descriptor::Create(TypeCode t,
std::size_t elementBytes, void *p, int rank, const SubscriptValue *extent,
ISO::CFI_attribute_t attribute, bool addendum,
diff --git a/flang-rt/unittests/Runtime/CMakeLists.txt b/flang-rt/unittests/Runtime/CMakeLists.txt
index cf1e15ddfa3e7..e51bc24415773 100644
--- a/flang-rt/unittests/Runtime/CMakeLists.txt
+++ b/flang-rt/unittests/Runtime/CMakeLists.txt
@@ -17,6 +17,7 @@ add_flangrt_unittest(RuntimeTests
Complex.cpp
CrashHandlerFixture.cpp
Derived.cpp
+ Descriptor.cpp
ExternalIOTest.cpp
Format.cpp
InputExtensions.cpp
diff --git a/flang-rt/unittests/Runtime/Descriptor.cpp b/flang-rt/unittests/Runtime/Descriptor.cpp
new file mode 100644
index 0000000000000..3a4a7670fc62e
--- /dev/null
+++ b/flang-rt/unittests/Runtime/Descriptor.cpp
@@ -0,0 +1,160 @@
+//===-- unittests/Runtime/Descriptor.cpp ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang-rt/runtime/descriptor.h"
+#include "tools.h"
+#include "gtest/gtest.h"
+
+using namespace Fortran::runtime;
+
+TEST(Descriptor, FixedStride) {
+ // Reuses one descriptor over an 8x8x8 array, mutating its dimensions per case.
+ StaticDescriptor<4> staticDesc[2];
+ Descriptor &descriptor{staticDesc[0].descriptor()};
+ using Type = std::int32_t;
+ Type data[8][8][8];
+ constexpr int four{static_cast<int>(sizeof data[0][0][0])};
+ TypeCode integer{TypeCategory::Integer, four};
+ // Scalar
+ descriptor.Establish(integer, four, data, 0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Empty vector
+ SubscriptValue extent[3]{0, 0, 0};
+ descriptor.Establish(integer, four, data, 1, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Contiguous vector (0:7:1)
+ extent[0] = 8;
+ descriptor.Establish(integer, four, data, 1, extent);
+ ASSERT_EQ(descriptor.rank(), 1);
+ ASSERT_EQ(descriptor.Elements(), 8);
+ ASSERT_EQ(descriptor.ElementBytes(), four);
+ ASSERT_EQ(descriptor.GetDimension(0).LowerBound(), 0);
+ ASSERT_EQ(descriptor.GetDimension(0).ByteStride(), four);
+ ASSERT_EQ(descriptor.GetDimension(0).Extent(), 8);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Contiguous reverse vector (7:0:-1)
+ descriptor.GetDimension(0).SetByteStride(-four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+ // Discontiguous vector (0:6:2)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+ // Empty matrix
+ extent[0] = 0;
+ descriptor.Establish(integer, four, data, 2, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Contiguous matrix (0:7, 0:7)
+ extent[0] = extent[1] = 8;
+ descriptor.Establish(integer, four, data, 2, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Contiguous row (0:7, 0)
+ descriptor.GetDimension(1).SetExtent(1);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Contiguous column (0, 0:7)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(1).SetExtent(8);
+ descriptor.GetDimension(1).SetByteStride(8 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * four);
+ // Contiguous reverse row (7:0:-1, 0)
+ descriptor.GetDimension(0).SetExtent(8);
+ descriptor.GetDimension(0).SetByteStride(-four);
+ descriptor.GetDimension(1).SetExtent(1);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), -four);
+ // Contiguous reverse column (0, 7:0:-1)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(8);
+ descriptor.GetDimension(1).SetByteStride(8 * -four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -four);
+ // Discontiguous row (0:6:2, 0)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ descriptor.GetDimension(1).SetExtent(1);
+ descriptor.GetDimension(1).SetByteStride(four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 2 * four);
+ // Discontiguous column (0, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * 2 * four);
+ // Discontiguous reverse row (7:1:-2, 0)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(-2 * four);
+ descriptor.GetDimension(1).SetExtent(1);
+ descriptor.GetDimension(1).SetByteStride(four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), -2 * four);
+ // Discontiguous reverse column (0, 7:1:-2)
+ descriptor.GetDimension(0).SetExtent(1);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * -2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 8 * -2 * four);
+ // Discontiguous rows (0:6:2, 0:1)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ descriptor.GetDimension(1).SetExtent(2);
+ descriptor.GetDimension(1).SetByteStride(8 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_FALSE(descriptor.FixedStride().has_value());
+ // Discontiguous columns (0:1, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(2);
+ descriptor.GetDimension(0).SetByteStride(four);
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_FALSE(descriptor.FixedStride().has_value());
+ // Empty 3-D array
+ extent[0] = extent[1] = extent[2] = 0;
+ descriptor.Establish(integer, four, data, 3, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Contiguous 3-D array (0:7, 0:7, 0:7)
+ extent[0] = extent[1] = extent[2] = 8;
+ descriptor.Establish(integer, four, data, 3, extent);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), four);
+ // Discontiguous 3-D array (0:7, 0:6:2, 0:6:2)
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ descriptor.GetDimension(2).SetExtent(4);
+ descriptor.GetDimension(2).SetByteStride(8 * 8 * 2 * four);
+ EXPECT_FALSE(descriptor.IsContiguous());
+ EXPECT_FALSE(descriptor.FixedStride().has_value());
+ // Discontiguous-looking empty 3-D array (0:-1, 0:6:2, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Discontiguous-looking empty 3-D array (0:6:2, 0:-1, 0:6:2)
+ descriptor.GetDimension(0).SetExtent(4);
+ descriptor.GetDimension(0).SetByteStride(2 * four);
+ descriptor.GetDimension(1).SetExtent(0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+ // Discontiguous-looking empty 3-D array (0:6:2, 0:6:2, 0:-1)
+ descriptor.GetDimension(1).SetExtent(4);
+ descriptor.GetDimension(1).SetByteStride(8 * 2 * four);
+ descriptor.GetDimension(2).SetExtent(0);
+ EXPECT_TRUE(descriptor.IsContiguous());
+ EXPECT_EQ(descriptor.FixedStride().value_or(-666), 0);
+}
More information about the llvm-commits
mailing list