[flang-commits] [flang] [flang] Fixed write past allocated descriptor in PointerAssociateRemapping. (PR #127000)
via flang-commits
flang-commits at lists.llvm.org
Wed Feb 12 19:10:17 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
The pointer descriptor might be smaller than the target descriptor
(for example, when the target has a higher rank), so `operator=` would write beyond the end of the pointer descriptor's storage.
---
Full diff: https://github.com/llvm/llvm-project/pull/127000.diff
2 Files Affected:
- (modified) flang/runtime/pointer.cpp (+7-10)
- (modified) flang/unittests/Runtime/Pointer.cpp (+44)
``````````diff
diff --git a/flang/runtime/pointer.cpp b/flang/runtime/pointer.cpp
index 2979181ddd61b..3b0babe3d852f 100644
--- a/flang/runtime/pointer.cpp
+++ b/flang/runtime/pointer.cpp
@@ -89,14 +89,18 @@ void RTDEF(PointerAssociateLowerBounds)(Descriptor &pointer,
void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
const Descriptor &target, const Descriptor &bounds, const char *sourceFile,
int sourceLine) {
- pointer = target;
- pointer.raw().attribute = CFI_attribute_pointer;
Terminator terminator{sourceFile, sourceLine};
SubscriptValue byteStride{/*captured from first dimension*/};
std::size_t boundElementBytes{bounds.ElementBytes()};
std::size_t boundsRank{
static_cast<std::size_t>(bounds.GetDimension(1).Extent())};
- pointer.raw().rank = boundsRank;
+ // We cannot just assign target into pointer descriptor, because
+ // the ranks may mismatch. Use target as a mold for initializing
+ // the pointer descriptor.
+ INTERNAL_CHECK(static_cast<std::size_t>(pointer.rank()) == boundsRank);
+ pointer.ApplyMold(target, boundsRank);
+ pointer.set_base_addr(target.raw().base_addr);
+ pointer.raw().attribute = CFI_attribute_pointer;
for (unsigned j{0}; j < boundsRank; ++j) {
auto &dim{pointer.GetDimension(j)};
dim.SetBounds(GetInt64(bounds.ZeroBasedIndexedElement<const char>(2 * j),
@@ -115,13 +119,6 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
"pointer (%zd > %zd)",
pointer.Elements(), target.Elements());
}
- if (auto *pointerAddendum{pointer.Addendum()}) {
- if (const auto *targetAddendum{target.Addendum()}) {
- if (const auto *derived{targetAddendum->derivedType()}) {
- pointerAddendum->set_derivedType(derived);
- }
- }
- }
}
RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) {
diff --git a/flang/unittests/Runtime/Pointer.cpp b/flang/unittests/Runtime/Pointer.cpp
index 4ce13ebc50a56..54720afab8d8a 100644
--- a/flang/unittests/Runtime/Pointer.cpp
+++ b/flang/unittests/Runtime/Pointer.cpp
@@ -105,3 +105,47 @@ TEST(Pointer, AllocateSourceZeroSize) {
EXPECT_EQ(p->GetDimension(0).UpperBound(), 0);
p->Destroy();
}
+
+TEST(Pointer, PointerAssociateRemapping) {
+ using Fortran::common::TypeCategory;
+ // REAL(4), POINTER :: p(:)
+ StaticDescriptor<Fortran::common::maxRank, true> staticDesc;
+ Descriptor &p{staticDesc.descriptor()}; // reference: must alias staticDesc's bytes
+ SubscriptValue extent[1]{1};
+ p.Establish(TypeCode{TypeCategory::Real, 4}, 4, nullptr, 1,
+ extent, CFI_attribute_pointer);
+ std::size_t descSize{p.SizeInBytes()};
+ EXPECT_LE(descSize, staticDesc.byteSize);
+ // REAL(4), CONTIGUOUS, POINTER :: t(:,:,:)
+ auto t{Descriptor::Create(TypeCode{TypeCategory::Real, 4}, 4,
+ nullptr, 3, nullptr, CFI_attribute_pointer)};
+ RTNAME(PointerSetBounds)(*t, 0, 1, 1);
+ RTNAME(PointerSetBounds)(*t, 1, 1, 1);
+ RTNAME(PointerSetBounds)(*t, 2, 1, 1);
+ RTNAME(PointerAllocate)(
+ *t, /*hasStat=*/false, /*errMsg=*/nullptr, __FILE__, __LINE__);
+ EXPECT_TRUE(RTNAME(PointerIsAssociated)(*t));
+ // INTEGER(4) :: b(2,1) = [[1,1]]
+ auto b{MakeArray<TypeCategory::Integer, 4>(
+ std::vector<int>{2, 1}, std::vector<std::int32_t>{1, 1})};
+ // p(1:1) => t
+ RTNAME(PointerAssociateRemapping)(p, *t, *b, __FILE__, __LINE__);
+ EXPECT_TRUE(RTNAME(PointerIsAssociated)(p));
+ EXPECT_EQ(p.rank(), 1);
+ EXPECT_EQ(p.Elements(), 1u);
+
+ // Verify that the bytes of staticDesc past p's descriptor were not written.
+ const char *addr = reinterpret_cast<const char *>(&staticDesc);
+ const char *ptr = addr + descSize;
+ const char *end = addr + staticDesc.byteSize;
+ while (ptr != end) {
+ if (*ptr != '\0') {
+ std::fprintf(stderr, "byte %zd after pointer descriptor was written\n",
+ ptr - addr);
+ EXPECT_EQ(*ptr, '\0');
+ break;
+ }
+ ++ptr;
+ }
+ p.Destroy();
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/127000
More information about the flang-commits
mailing list