[flang-commits] [flang] 514c1ec - [flang][runtime] Interoperable POINTER deallocation validation (#96100)
via flang-commits
flang-commits at lists.llvm.org
Mon Jun 24 10:46:33 PDT 2024
Author: Peter Klausler
Date: 2024-06-24T10:46:30-07:00
New Revision: 514c1ec5477a48e4f639c0b15ab757832b67dd10
URL: https://github.com/llvm/llvm-project/commit/514c1ec5477a48e4f639c0b15ab757832b67dd10
DIFF: https://github.com/llvm/llvm-project/commit/514c1ec5477a48e4f639c0b15ab757832b67dd10.diff
LOG: [flang][runtime] Interoperable POINTER deallocation validation (#96100)
Extend the runtime validation of deallocated pointers so that it also
works when pointers are allocated and/or deallocated outside Fortran.
Previously, bogus runtime errors would be reported for pointers
allocated via CFI_allocate() and deallocated in Fortran, and
CFI_deallocate() did not check that it was deallocating a whole
contiguous pointer that was allocated as such.
Added:
Modified:
flang/include/flang/Runtime/pointer.h
flang/runtime/ISO_Fortran_binding.cpp
flang/runtime/descriptor.cpp
flang/runtime/pointer.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/pointer.h b/flang/include/flang/Runtime/pointer.h
index 6ceb70ebb676d..704144f08114f 100644
--- a/flang/include/flang/Runtime/pointer.h
+++ b/flang/include/flang/Runtime/pointer.h
@@ -115,6 +115,11 @@ bool RTDECL(PointerIsAssociated)(const Descriptor &);
bool RTDECL(PointerIsAssociatedWith)(
const Descriptor &, const Descriptor *target);
+// Fortran POINTERs are allocated with an extra validation word after their
+// payloads in order to detect erroneous deallocations later.
+RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t);
+RT_API_ATTRS bool ValidatePointerPayload(const ISO::CFI_cdesc_t &);
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_POINTER_H_
diff --git a/flang/runtime/ISO_Fortran_binding.cpp b/flang/runtime/ISO_Fortran_binding.cpp
index 99ba3aa56feee..fe22026f31f55 100644
--- a/flang/runtime/ISO_Fortran_binding.cpp
+++ b/flang/runtime/ISO_Fortran_binding.cpp
@@ -13,6 +13,7 @@
#include "terminator.h"
#include "flang/ISO_Fortran_binding_wrapper.h"
#include "flang/Runtime/descriptor.h"
+#include "flang/Runtime/pointer.h"
#include "flang/Runtime/type-code.h"
#include <cstdlib>
@@ -75,7 +76,7 @@ RT_API_ATTRS int CFI_allocate(CFI_cdesc_t *descriptor,
dim->sm = byteSize;
byteSize *= extent;
}
- void *p{byteSize ? std::malloc(byteSize) : std::malloc(1)};
+ void *p{runtime::AllocateValidatedPointerPayload(byteSize)};
if (!p && byteSize) {
return CFI_ERROR_MEM_ALLOCATION;
}
@@ -91,8 +92,11 @@ RT_API_ATTRS int CFI_deallocate(CFI_cdesc_t *descriptor) {
if (descriptor->version != CFI_VERSION) {
return CFI_INVALID_DESCRIPTOR;
}
- if (descriptor->attribute != CFI_attribute_allocatable &&
- descriptor->attribute != CFI_attribute_pointer) {
+ if (descriptor->attribute == CFI_attribute_pointer) {
+ if (!runtime::ValidatePointerPayload(*descriptor)) {
+ return CFI_INVALID_DESCRIPTOR;
+ }
+ } else if (descriptor->attribute != CFI_attribute_allocatable) {
// Non-interoperable object
return CFI_INVALID_DESCRIPTOR;
}
diff --git a/flang/runtime/descriptor.cpp b/flang/runtime/descriptor.cpp
index d8b51f1be0c5c..9b04cb4f8d0d0 100644
--- a/flang/runtime/descriptor.cpp
+++ b/flang/runtime/descriptor.cpp
@@ -199,7 +199,16 @@ RT_API_ATTRS int Descriptor::Destroy(
}
}
-RT_API_ATTRS int Descriptor::Deallocate() { return ISO::CFI_deallocate(&raw_); }
+RT_API_ATTRS int Descriptor::Deallocate() {
+ ISO::CFI_cdesc_t &descriptor{raw()};
+ if (!descriptor.base_addr) {
+ return CFI_ERROR_BASE_ADDR_NULL;
+ } else {
+ std::free(descriptor.base_addr);
+ descriptor.base_addr = nullptr;
+ return CFI_SUCCESS;
+ }
+}
RT_API_ATTRS bool Descriptor::DecrementSubscripts(
SubscriptValue *subscript, const int *permutation) const {
diff --git a/flang/runtime/pointer.cpp b/flang/runtime/pointer.cpp
index 08a1223764f39..aeed879f1a2e2 100644
--- a/flang/runtime/pointer.cpp
+++ b/flang/runtime/pointer.cpp
@@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
}
}
+RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) {
+ // Add space for a footer to validate during deallocation.
+ constexpr std::size_t align{sizeof(std::uintptr_t)};
+ byteSize = ((byteSize / align) + 1) * align;
+ std::size_t total{byteSize + sizeof(std::uintptr_t)};
+ void *p{std::malloc(total)};
+ if (p) {
+ // Fill the footer word with the XOR of the ones' complement of
+ // the base address, which is a value that would be highly unlikely
+ // to appear accidentally at the right spot.
+ std::uintptr_t *footer{
+ reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
+ *footer = ~reinterpret_cast<std::uintptr_t>(p);
+ }
+ return p;
+}
+
int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
@@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
elementBytes = pointer.raw().elem_len = 0;
}
std::size_t byteSize{pointer.Elements() * elementBytes};
- // Add space for a footer to validate during DEALLOCATE.
- constexpr std::size_t align{sizeof(std::uintptr_t)};
- byteSize = ((byteSize + align - 1) / align) * align;
- std::size_t total{byteSize + sizeof(std::uintptr_t)};
- void *p{std::malloc(total)};
+ void *p{AllocateValidatedPointerPayload(byteSize)};
if (!p) {
return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat);
}
pointer.set_base_addr(p);
pointer.SetByteStrides();
- // Fill the footer word with the XOR of the ones' complement of
- // the base address, which is a value that would be highly unlikely
- // to appear accidentally at the right spot.
- std::uintptr_t *footer{
- reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
- *footer = ~reinterpret_cast<std::uintptr_t>(p);
int stat{StatOk};
if (const DescriptorAddendum * addendum{pointer.Addendum()}) {
if (const auto *derived{addendum->derivedType()}) {
@@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source,
return stat;
}
+static RT_API_ATTRS std::size_t GetByteSize(
+ const ISO::CFI_cdesc_t &descriptor) {
+ std::size_t rank{descriptor.rank};
+ const ISO::CFI_dim_t *dim{descriptor.dim};
+ std::size_t byteSize{descriptor.elem_len};
+ for (std::size_t j{0}; j < rank; ++j) {
+ byteSize *= dim[j].extent;
+ }
+ return byteSize;
+}
+
+bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) {
+ std::size_t byteSize{GetByteSize(desc)};
+ constexpr std::size_t align{sizeof(std::uintptr_t)};
+ byteSize = ((byteSize / align) + 1) * align;
+ const void *p{desc.base_addr};
+ const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>(
+ static_cast<const char *>(p) + byteSize)};
+ return *footer == ~reinterpret_cast<std::uintptr_t>(p);
+}
+
int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
@@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
if (!pointer.IsAllocated()) {
return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
}
- if (executionEnvironment.checkPointerDeallocation) {
- // Validate the footer. This should fail if the pointer doesn't
- // span the entire object, or the object was not allocated as a
- // pointer.
- std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()};
- constexpr std::size_t align{sizeof(std::uintptr_t)};
- byteSize = ((byteSize + align - 1) / align) * align;
- void *p{pointer.raw().base_addr};
- std::uintptr_t *footer{
- reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
- if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) {
- return ReturnError(
- terminator, StatBadPointerDeallocation, errMsg, hasStat);
- }
+ if (executionEnvironment.checkPointerDeallocation &&
+ !ValidatePointerPayload(pointer.raw())) {
+ return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat);
}
return ReturnError(terminator,
pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator),
More information about the flang-commits
mailing list