[libc-commits] [libc] [libc][stdlib] Fix UB in freelist (PR #95330)

via libc-commits libc-commits at lists.llvm.org
Thu Jun 13 14:48:57 PDT 2024


https://github.com/PiJoules updated https://github.com/llvm/llvm-project/pull/95330

>From e304a815a3b582e3677a66ac82bb4378b6dc3280 Mon Sep 17 00:00:00 2001
From: Leonard Chan <leonardchan at google.com>
Date: Wed, 12 Jun 2024 16:16:08 -0700
Subject: [PATCH] [libc][stdlib] Fix UB in freelist

Some of the freelist code uses union-based type punning, which is UB in
C++ because it reads from a union member that is not the active member.
Replace the unions with placement new and reinterpret_cast.
---
 libc/src/stdlib/CMakeLists.txt |  3 +-
 libc/src/stdlib/freelist.h     | 53 ++++++++++++----------------------
 2 files changed, 20 insertions(+), 36 deletions(-)

diff --git a/libc/src/stdlib/CMakeLists.txt b/libc/src/stdlib/CMakeLists.txt
index 971b39bb900de..6d2c5acca9605 100644
--- a/libc/src/stdlib/CMakeLists.txt
+++ b/libc/src/stdlib/CMakeLists.txt
@@ -398,8 +398,9 @@ else()
       freelist.h
     DEPENDS
       libc.src.__support.fixedvector
-      libc.src.__support.CPP.cstddef
       libc.src.__support.CPP.array
+      libc.src.__support.CPP.cstddef
+      libc.src.__support.CPP.new
       libc.src.__support.CPP.span
   )
   add_header_library(
diff --git a/libc/src/stdlib/freelist.h b/libc/src/stdlib/freelist.h
index c01ed6eddb7d4..789bc164fb161 100644
--- a/libc/src/stdlib/freelist.h
+++ b/libc/src/stdlib/freelist.h
@@ -11,6 +11,7 @@
 
 #include "src/__support/CPP/array.h"
 #include "src/__support/CPP/cstddef.h"
+#include "src/__support/CPP/new.h"
 #include "src/__support/CPP/span.h"
 #include "src/__support/fixedvector.h"
 
@@ -92,19 +93,12 @@ bool FreeList<NUM_BUCKETS>::add_chunk(span<cpp::byte> chunk) {
   if (chunk.size() < sizeof(FreeListNode))
     return false;
 
-  union {
-    FreeListNode *node;
-    cpp::byte *bytes;
-  } aliased;
-
-  aliased.bytes = chunk.data();
-
+  // Add it to the correct list.
   size_t chunk_ptr = find_chunk_ptr_for_size(chunk.size(), false);
 
-  // Add it to the correct list.
-  aliased.node->size = chunk.size();
-  aliased.node->next = chunks_[chunk_ptr];
-  chunks_[chunk_ptr] = aliased.node;
+  FreeListNode *node =
+      ::new (chunk.data()) FreeListNode{chunks_[chunk_ptr], chunk.size()};
+  chunks_[chunk_ptr] = node;
 
   return true;
 }
@@ -123,17 +117,13 @@ span<cpp::byte> FreeList<NUM_BUCKETS>::find_chunk(size_t size) const {
 
   // Now iterate up the buckets, walking each list to find a good candidate
   for (size_t i = chunk_ptr; i < chunks_.size(); i++) {
-    union {
-      FreeListNode *node;
-      cpp::byte *data;
-    } aliased;
-    aliased.node = chunks_[static_cast<unsigned short>(i)];
+    FreeListNode *node = chunks_[static_cast<unsigned short>(i)];
 
-    while (aliased.node != nullptr) {
-      if (aliased.node->size >= size)
-        return span<cpp::byte>(aliased.data, aliased.node->size);
+    while (node != nullptr) {
+      if (node->size >= size)
+        return span<cpp::byte>(reinterpret_cast<cpp::byte *>(node), node->size);
 
-      aliased.node = aliased.node->next;
+      node = node->next;
     }
   }
 
@@ -146,34 +136,27 @@ template <size_t NUM_BUCKETS>
 bool FreeList<NUM_BUCKETS>::remove_chunk(span<cpp::byte> chunk) {
   size_t chunk_ptr = find_chunk_ptr_for_size(chunk.size(), true);
 
-  // Walk that list, finding the chunk.
-  union {
-    FreeListNode *node;
-    cpp::byte *data;
-  } aliased, aliased_next;
-
   // Check head first.
   if (chunks_[chunk_ptr] == nullptr)
     return false;
 
-  aliased.node = chunks_[chunk_ptr];
-  if (aliased.data == chunk.data()) {
-    chunks_[chunk_ptr] = aliased.node->next;
+  FreeListNode *node = chunks_[chunk_ptr];
+  if (reinterpret_cast<cpp::byte *>(node) == chunk.data()) {
+    chunks_[chunk_ptr] = node->next;
     return true;
   }
 
   // No? Walk the nodes.
-  aliased.node = chunks_[chunk_ptr];
+  node = chunks_[chunk_ptr];
 
-  while (aliased.node->next != nullptr) {
-    aliased_next.node = aliased.node->next;
-    if (aliased_next.data == chunk.data()) {
+  while (node->next != nullptr) {
+    if (reinterpret_cast<cpp::byte *>(node->next) == chunk.data()) {
       // Found it, remove this node out of the chain
-      aliased.node->next = aliased_next.node->next;
+      node->next = node->next->next;
       return true;
     }
 
-    aliased.node = aliased.node->next;
+    node = node->next;
   }
 
   return false;



More information about the libc-commits mailing list