[flang-commits] [flang] [flang][runtime] Interoperable POINTER deallocation validation (PR #96100)

via flang-commits flang-commits at lists.llvm.org
Wed Jun 19 11:17:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-runtime

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

Extend the runtime validation of deallocated pointers so that it also works when pointers are allocated and/or deallocated outside Fortran. Previously, spurious runtime errors were reported for pointers allocated via CFI_allocate() and deallocated in Fortran, and CFI_deallocate() did not check that it was deallocating a whole contiguous pointer that had been allocated as such.

---
Full diff: https://github.com/llvm/llvm-project/pull/96100.diff


4 Files Affected:

- (modified) flang/include/flang/Runtime/pointer.h (+5) 
- (modified) flang/runtime/ISO_Fortran_binding.cpp (+7-3) 
- (modified) flang/runtime/descriptor.cpp (+10-1) 
- (modified) flang/runtime/pointer.cpp (+42-25) 


``````````diff
diff --git a/flang/include/flang/Runtime/pointer.h b/flang/include/flang/Runtime/pointer.h
index 6ceb70ebb676d..704144f08114f 100644
--- a/flang/include/flang/Runtime/pointer.h
+++ b/flang/include/flang/Runtime/pointer.h
@@ -115,6 +115,11 @@ bool RTDECL(PointerIsAssociated)(const Descriptor &);
 bool RTDECL(PointerIsAssociatedWith)(
     const Descriptor &, const Descriptor *target);
 
+// Fortran POINTERs (including CFI_allocate()d ones) get an extra validation
+// word after their payloads so that erroneous deallocations can be detected.
+RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t);
+RT_API_ATTRS bool ValidatePointerPayload(const ISO::CFI_cdesc_t &);
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_POINTER_H_
diff --git a/flang/runtime/ISO_Fortran_binding.cpp b/flang/runtime/ISO_Fortran_binding.cpp
index 99ba3aa56feee..fe22026f31f55 100644
--- a/flang/runtime/ISO_Fortran_binding.cpp
+++ b/flang/runtime/ISO_Fortran_binding.cpp
@@ -13,6 +13,7 @@
 #include "terminator.h"
 #include "flang/ISO_Fortran_binding_wrapper.h"
 #include "flang/Runtime/descriptor.h"
+#include "flang/Runtime/pointer.h"
 #include "flang/Runtime/type-code.h"
 #include <cstdlib>
 
@@ -75,7 +76,7 @@ RT_API_ATTRS int CFI_allocate(CFI_cdesc_t *descriptor,
     dim->sm = byteSize;
     byteSize *= extent;
   }
-  void *p{byteSize ? std::malloc(byteSize) : std::malloc(1)};
+  void *p{runtime::AllocateValidatedPointerPayload(byteSize)};
   if (!p && byteSize) {
     return CFI_ERROR_MEM_ALLOCATION;
   }
@@ -91,8 +92,11 @@ RT_API_ATTRS int CFI_deallocate(CFI_cdesc_t *descriptor) {
   if (descriptor->version != CFI_VERSION) {
     return CFI_INVALID_DESCRIPTOR;
   }
-  if (descriptor->attribute != CFI_attribute_allocatable &&
-      descriptor->attribute != CFI_attribute_pointer) {
+  if (descriptor->attribute == CFI_attribute_pointer) {
+    if (!runtime::ValidatePointerPayload(*descriptor)) {
+      return CFI_INVALID_DESCRIPTOR;
+    }
+  } else if (descriptor->attribute != CFI_attribute_allocatable) {
     // Non-interoperable object
     return CFI_INVALID_DESCRIPTOR;
   }
diff --git a/flang/runtime/descriptor.cpp b/flang/runtime/descriptor.cpp
index d8b51f1be0c5c..9b04cb4f8d0d0 100644
--- a/flang/runtime/descriptor.cpp
+++ b/flang/runtime/descriptor.cpp
@@ -199,7 +199,16 @@ RT_API_ATTRS int Descriptor::Destroy(
   }
 }
 
-RT_API_ATTRS int Descriptor::Deallocate() { return ISO::CFI_deallocate(&raw_); }
+RT_API_ATTRS int Descriptor::Deallocate() { // frees base_addr directly
+  ISO::CFI_cdesc_t &descriptor{raw()};
+  if (!descriptor.base_addr) {
+    return CFI_ERROR_BASE_ADDR_NULL;
+  } else {
+    std::free(descriptor.base_addr); // bypasses CFI_deallocate() footer check
+    descriptor.base_addr = nullptr;
+    return CFI_SUCCESS;
+  }
+}
 
 RT_API_ATTRS bool Descriptor::DecrementSubscripts(
     SubscriptValue *subscript, const int *permutation) const {
diff --git a/flang/runtime/pointer.cpp b/flang/runtime/pointer.cpp
index 08a1223764f39..aeed879f1a2e2 100644
--- a/flang/runtime/pointer.cpp
+++ b/flang/runtime/pointer.cpp
@@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
   }
 }
 
+RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) {
+  // Footer offset: the smallest word multiple strictly greater than byteSize.
+  constexpr std::size_t align{sizeof(std::uintptr_t)};
+  byteSize = ((byteSize / align) + 1) * align;
+  std::size_t total{byteSize + sizeof(std::uintptr_t)};
+  void *p{std::malloc(total)};
+  if (p) {
+    // Store the ones' complement of the base address in the footer word; that
+    // value is unlikely to appear there by accident, and
+    // ValidatePointerPayload() recomputes the same offset to check it.
+    std::uintptr_t *footer{
+        reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
+    *footer = ~reinterpret_cast<std::uintptr_t>(p);
+  }
+  return p;
+}
+
 int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
     const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
@@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
     elementBytes = pointer.raw().elem_len = 0;
   }
   std::size_t byteSize{pointer.Elements() * elementBytes};
-  // Add space for a footer to validate during DEALLOCATE.
-  constexpr std::size_t align{sizeof(std::uintptr_t)};
-  byteSize = ((byteSize + align - 1) / align) * align;
-  std::size_t total{byteSize + sizeof(std::uintptr_t)};
-  void *p{std::malloc(total)};
+  void *p{AllocateValidatedPointerPayload(byteSize)};
   if (!p) {
     return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat);
   }
   pointer.set_base_addr(p);
   pointer.SetByteStrides();
-  // Fill the footer word with the XOR of the ones' complement of
-  // the base address, which is a value that would be highly unlikely
-  // to appear accidentally at the right spot.
-  std::uintptr_t *footer{
-      reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
-  *footer = ~reinterpret_cast<std::uintptr_t>(p);
   int stat{StatOk};
   if (const DescriptorAddendum * addendum{pointer.Addendum()}) {
     if (const auto *derived{addendum->derivedType()}) {
@@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source,
   return stat;
 }
 
+static RT_API_ATTRS std::size_t GetByteSize( // elem_len * product of extents
+    const ISO::CFI_cdesc_t &descriptor) {
+  std::size_t rank{descriptor.rank};
+  const ISO::CFI_dim_t *dim{descriptor.dim};
+  std::size_t byteSize{descriptor.elem_len};
+  for (std::size_t j{0}; j < rank; ++j) {
+    byteSize *= dim[j].extent; // NOTE(review): no overflow/negative-extent check
+  }
+  return byteSize;
+}
+
+bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) { // footer ok?
+  std::size_t byteSize{GetByteSize(desc)};
+  constexpr std::size_t align{sizeof(std::uintptr_t)};
+  byteSize = ((byteSize / align) + 1) * align; // must match the allocator's rounding
+  const void *p{desc.base_addr}; // NOTE(review): not null-checked; callers must ensure
+  const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>(
+      static_cast<const char *>(p) + byteSize)};
+  return *footer == ~reinterpret_cast<std::uintptr_t>(p);
+}
+
 int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
     const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
@@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
   if (!pointer.IsAllocated()) {
     return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
   }
-  if (executionEnvironment.checkPointerDeallocation) {
-    // Validate the footer.  This should fail if the pointer doesn't
-    // span the entire object, or the object was not allocated as a
-    // pointer.
-    std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()};
-    constexpr std::size_t align{sizeof(std::uintptr_t)};
-    byteSize = ((byteSize + align - 1) / align) * align;
-    void *p{pointer.raw().base_addr};
-    std::uintptr_t *footer{
-        reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
-    if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) {
-      return ReturnError(
-          terminator, StatBadPointerDeallocation, errMsg, hasStat);
-    }
+  if (executionEnvironment.checkPointerDeallocation &&
+      !ValidatePointerPayload(pointer.raw())) {
+    return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat);
   }
   return ReturnError(terminator,
       pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator),

``````````

</details>


https://github.com/llvm/llvm-project/pull/96100


More information about the flang-commits mailing list