[libcxx-commits] [libcxx] [libc++] Fix UB in <expected> related to "has value" flag (#68552) (PR #68733)

Jan Kokemüller via libcxx-commits libcxx-commits at lists.llvm.org
Wed Oct 11 11:42:56 PDT 2023


https://github.com/jiixyj updated https://github.com/llvm/llvm-project/pull/68733

>From 9832a97756cdb800947827ff881f0509159bcc12 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Tue, 10 Oct 2023 20:24:06 +0200
Subject: [PATCH] [libc++] Fix UB in <expected> related to "has value" flag
 (#68552)

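The `__has_val_` flag may be placed in the tail padding of the
value/error object, so the calls to `std::construct_at` might overwrite
a previously set flag. Reverse the order everywhere, so that the flag is
set only after the value/error has been constructed. Where possible,
avoid calling `std::construct_at` and construct the value/error
directly into the union.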
---
 libcxx/include/__expected/expected.h          | 168 ++++++++----------
 .../observers/has_value.pass.cpp              |  39 ++++
 2 files changed, 117 insertions(+), 90 deletions(-)

diff --git a/libcxx/include/__expected/expected.h b/libcxx/include/__expected/expected.h
index 045370a486fae6b..08f35b1111f6bf9 100644
--- a/libcxx/include/__expected/expected.h
+++ b/libcxx/include/__expected/expected.h
@@ -119,9 +119,7 @@ class expected {
   _LIBCPP_HIDE_FROM_ABI constexpr expected()
     noexcept(is_nothrow_default_constructible_v<_Tp>) // strengthened
     requires is_default_constructible_v<_Tp>
-      : __has_val_(true) {
-    std::construct_at(std::addressof(__union_.__val_));
-  }
+      : __union_(__construct_in_place_tag{}), __has_val_(true) {}
 
   _LIBCPP_HIDE_FROM_ABI constexpr expected(const expected&) = delete;
 
@@ -136,14 +134,7 @@ class expected {
     noexcept(is_nothrow_copy_constructible_v<_Tp> && is_nothrow_copy_constructible_v<_Err>) // strengthened
     requires(is_copy_constructible_v<_Tp> && is_copy_constructible_v<_Err> &&
              !(is_trivially_copy_constructible_v<_Tp> && is_trivially_copy_constructible_v<_Err>))
-      : __has_val_(__other.__has_val_) {
-    if (__has_val_) {
-      std::construct_at(std::addressof(__union_.__val_), __other.__union_.__val_);
-    } else {
-      std::construct_at(std::addressof(__union_.__unex_), __other.__union_.__unex_);
-    }
-  }
-
+      : __union_(__union_from_expected(__other)), __has_val_(__other.__has_val_) { }
 
   _LIBCPP_HIDE_FROM_ABI constexpr expected(expected&&)
     requires(is_move_constructible_v<_Tp> && is_move_constructible_v<_Err>
@@ -154,13 +145,7 @@ class expected {
     noexcept(is_nothrow_move_constructible_v<_Tp> && is_nothrow_move_constructible_v<_Err>)
     requires(is_move_constructible_v<_Tp> && is_move_constructible_v<_Err> &&
              !(is_trivially_move_constructible_v<_Tp> && is_trivially_move_constructible_v<_Err>))
-      : __has_val_(__other.__has_val_) {
-    if (__has_val_) {
-      std::construct_at(std::addressof(__union_.__val_), std::move(__other.__union_.__val_));
-    } else {
-      std::construct_at(std::addressof(__union_.__unex_), std::move(__other.__union_.__unex_));
-    }
-  }
+      : __union_(__union_from_expected(std::move(__other))), __has_val_(__other.__has_val_) { }
 
 private:
   template <class _Up, class _OtherErr, class _UfQual, class _OtherErrQual>
@@ -200,26 +185,14 @@ class expected {
   expected(const expected<_Up, _OtherErr>& __other)
     noexcept(is_nothrow_constructible_v<_Tp, const _Up&> &&
              is_nothrow_constructible_v<_Err, const _OtherErr&>) // strengthened
-      : __has_val_(__other.__has_val_) {
-    if (__has_val_) {
-      std::construct_at(std::addressof(__union_.__val_), __other.__union_.__val_);
-    } else {
-      std::construct_at(std::addressof(__union_.__unex_), __other.__union_.__unex_);
-    }
-  }
+      : __union_(__union_from_expected(__other)), __has_val_(__other.__has_val_) {}
 
   template <class _Up, class _OtherErr>
     requires __can_convert<_Up, _OtherErr, _Up, _OtherErr>::value
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<_Up, _Tp> || !is_convertible_v<_OtherErr, _Err>)
   expected(expected<_Up, _OtherErr>&& __other)
     noexcept(is_nothrow_constructible_v<_Tp, _Up> && is_nothrow_constructible_v<_Err, _OtherErr>) // strengthened
-      : __has_val_(__other.__has_val_) {
-    if (__has_val_) {
-      std::construct_at(std::addressof(__union_.__val_), std::move(__other.__union_.__val_));
-    } else {
-      std::construct_at(std::addressof(__union_.__unex_), std::move(__other.__union_.__unex_));
-    }
-  }
+      : __union_(__union_from_expected(std::move(__other))), __has_val_(__other.__has_val_) {}
 
   template <class _Up = _Tp>
     requires(!is_same_v<remove_cvref_t<_Up>, in_place_t> && !is_same_v<expected, remove_cvref_t<_Up>> &&
@@ -227,61 +200,47 @@ class expected {
              (!is_same_v<remove_cv_t<_Tp>, bool> || !__is_std_expected<remove_cvref_t<_Up>>::value))
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<_Up, _Tp>)
       expected(_Up&& __u) noexcept(is_nothrow_constructible_v<_Tp, _Up>) // strengthened
-      : __has_val_(true) {
-    std::construct_at(std::addressof(__union_.__val_), std::forward<_Up>(__u));
-  }
+      : __union_(__construct_in_place_tag{}, std::forward<_Up>(__u)), __has_val_(true) {}
 
   template <class _OtherErr>
     requires is_constructible_v<_Err, const _OtherErr&>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<const _OtherErr&, _Err>)
   expected(const unexpected<_OtherErr>& __unex)
     noexcept(is_nothrow_constructible_v<_Err, const _OtherErr&>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), __unex.error());
-  }
+      : __union_(__construct_unexpected_tag{}, __unex.error()), __has_val_(false) {}
 
   template <class _OtherErr>
     requires is_constructible_v<_Err, _OtherErr>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<_OtherErr, _Err>)
   expected(unexpected<_OtherErr>&& __unex)
     noexcept(is_nothrow_constructible_v<_Err, _OtherErr>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), std::move(__unex.error()));
-  }
+      : __union_(__construct_unexpected_tag{}, std::move(__unex.error())), __has_val_(false) {}
 
   template <class... _Args>
     requires is_constructible_v<_Tp, _Args...>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit expected(in_place_t, _Args&&... __args)
     noexcept(is_nothrow_constructible_v<_Tp, _Args...>) // strengthened
-      : __has_val_(true) {
-    std::construct_at(std::addressof(__union_.__val_), std::forward<_Args>(__args)...);
-  }
+      : __union_(__construct_in_place_tag{}, std::forward<_Args>(__args)...), __has_val_(true) {}
 
   template <class _Up, class... _Args>
     requires is_constructible_v< _Tp, initializer_list<_Up>&, _Args... >
   _LIBCPP_HIDE_FROM_ABI constexpr explicit
   expected(in_place_t, initializer_list<_Up> __il, _Args&&... __args)
     noexcept(is_nothrow_constructible_v<_Tp, initializer_list<_Up>&, _Args...>) // strengthened
-      : __has_val_(true) {
-    std::construct_at(std::addressof(__union_.__val_), __il, std::forward<_Args>(__args)...);
-  }
+      : __union_(__construct_in_place_tag{}, __il, std::forward<_Args>(__args)...), __has_val_(true) {}
 
   template <class... _Args>
     requires is_constructible_v<_Err, _Args...>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit expected(unexpect_t, _Args&&... __args)
-    noexcept(is_nothrow_constructible_v<_Err, _Args...>)  // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), std::forward<_Args>(__args)...);
-  }
+    noexcept(is_nothrow_constructible_v<_Err, _Args...>) // strengthened
+      : __union_(__construct_unexpected_tag{}, std::forward<_Args>(__args)...), __has_val_(false) {}
 
   template <class _Up, class... _Args>
     requires is_constructible_v< _Err, initializer_list<_Up>&, _Args... >
   _LIBCPP_HIDE_FROM_ABI constexpr explicit
   expected(unexpect_t, initializer_list<_Up> __il, _Args&&... __args)
     noexcept(is_nothrow_constructible_v<_Err, initializer_list<_Up>&, _Args...>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), __il, std::forward<_Args>(__args)...);
-  }
+      : __union_(__construct_unexpected_tag{}, __il, std::forward<_Args>(__args)...), __has_val_(false) {}
 
   // [expected.object.dtor], destructor
 
@@ -440,9 +399,10 @@ class expected {
       std::destroy_at(std::addressof(__union_.__val_));
     } else {
       std::destroy_at(std::addressof(__union_.__unex_));
-      __has_val_ = true;
     }
-    return *std::construct_at(std::addressof(__union_.__val_), std::forward<_Args>(__args)...);
+    std::construct_at(std::addressof(__union_.__val_), std::forward<_Args>(__args)...);
+    __has_val_ = true;
+    return *std::addressof(__union_.__val_);
   }
 
   template <class _Up, class... _Args>
@@ -452,9 +412,10 @@ class expected {
       std::destroy_at(std::addressof(__union_.__val_));
     } else {
       std::destroy_at(std::addressof(__union_.__unex_));
-      __has_val_ = true;
     }
-    return *std::construct_at(std::addressof(__union_.__val_), __il, std::forward<_Args>(__args)...);
+    std::construct_at(std::addressof(__union_.__val_), __il, std::forward<_Args>(__args)...);
+    __has_val_ = true;
+    return *std::addressof(__union_.__val_);
   }
 
 
@@ -894,11 +855,21 @@ class expected {
 
 private:
   struct __empty_t {};
+  struct __construct_in_place_tag {};
+  struct __construct_unexpected_tag {};
 
   template <class _ValueType, class _ErrorType>
   union __union_t {
     _LIBCPP_HIDE_FROM_ABI constexpr __union_t() {}
 
+    template <class... _Args>
+    _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(__construct_in_place_tag, _Args&&... __args)
+        : __val_(std::forward<_Args>(__args)...) {}
+
+    template <class... _Args>
+    _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(__construct_unexpected_tag, _Args&&... __args)
+        : __unex_(std::forward<_Args>(__args)...) {}
+
     template <class _Func, class... _Args>
     _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(
         std::__expected_construct_in_place_from_invoke_tag, _Func&& __f, _Args&&... __args)
@@ -931,6 +902,14 @@ class expected {
     _LIBCPP_HIDE_FROM_ABI constexpr __union_t(const __union_t&) = default;
     _LIBCPP_HIDE_FROM_ABI constexpr __union_t& operator=(const __union_t&) = default;
 
+    template <class... _Args>
+    _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(__construct_in_place_tag, _Args&&... __args)
+        : __val_(std::forward<_Args>(__args)...) {}
+
+    template <class... _Args>
+    _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(__construct_unexpected_tag, _Args&&... __args)
+        : __unex_(std::forward<_Args>(__args)...) {}
+
     template <class _Func, class... _Args>
     _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(
         std::__expected_construct_in_place_from_invoke_tag, _Func&& __f, _Args&&... __args)
@@ -955,6 +934,18 @@ class expected {
     _LIBCPP_NO_UNIQUE_ADDRESS _ErrorType __unex_;
   };
 
+  template <class _Up, class _OtherErr>
+  _LIBCPP_HIDE_FROM_ABI constexpr __union_t<_Tp, _Err> __union_from_expected(const expected<_Up, _OtherErr>& __other) {
+    return __other.__has_val_ ? __union_t<_Tp, _Err>(__construct_in_place_tag{}, __other.__union_.__val_)
+                              : __union_t<_Tp, _Err>(__construct_unexpected_tag{}, __other.__union_.__unex_);
+  }
+
+  template <class _Up, class _OtherErr>
+  _LIBCPP_HIDE_FROM_ABI constexpr __union_t<_Tp, _Err> __union_from_expected(expected<_Up, _OtherErr>&& __other) {
+    return __other.__has_val_ ? __union_t<_Tp, _Err>(__construct_in_place_tag{}, std::move(__other.__union_.__val_))
+                              : __union_t<_Tp, _Err>(__construct_unexpected_tag{}, std::move(__other.__union_.__unex_));
+  }
+
   _LIBCPP_NO_UNIQUE_ADDRESS __union_t<_Tp, _Err> __union_;
   bool __has_val_;
 };
@@ -998,11 +989,7 @@ class expected<_Tp, _Err> {
   _LIBCPP_HIDE_FROM_ABI constexpr expected(const expected& __rhs)
     noexcept(is_nothrow_copy_constructible_v<_Err>) // strengthened
     requires(is_copy_constructible_v<_Err> && !is_trivially_copy_constructible_v<_Err>)
-      : __has_val_(__rhs.__has_val_) {
-    if (!__rhs.__has_val_) {
-      std::construct_at(std::addressof(__union_.__unex_), __rhs.__union_.__unex_);
-    }
-  }
+      : __union_(__union_from_expected(__rhs)), __has_val_(__rhs.__has_val_) {}
 
   _LIBCPP_HIDE_FROM_ABI constexpr expected(expected&&)
     requires(is_move_constructible_v<_Err> && is_trivially_move_constructible_v<_Err>)
@@ -1011,51 +998,35 @@ class expected<_Tp, _Err> {
   _LIBCPP_HIDE_FROM_ABI constexpr expected(expected&& __rhs)
     noexcept(is_nothrow_move_constructible_v<_Err>)
     requires(is_move_constructible_v<_Err> && !is_trivially_move_constructible_v<_Err>)
-      : __has_val_(__rhs.__has_val_) {
-    if (!__rhs.__has_val_) {
-      std::construct_at(std::addressof(__union_.__unex_), std::move(__rhs.__union_.__unex_));
-    }
-  }
+      : __union_(__union_from_expected(std::move(__rhs))), __has_val_(__rhs.__has_val_) {}
 
   template <class _Up, class _OtherErr>
     requires __can_convert<_Up, _OtherErr, const _OtherErr&>::value
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<const _OtherErr&, _Err>)
   expected(const expected<_Up, _OtherErr>& __rhs)
     noexcept(is_nothrow_constructible_v<_Err, const _OtherErr&>) // strengthened
-      : __has_val_(__rhs.__has_val_) {
-    if (!__rhs.__has_val_) {
-      std::construct_at(std::addressof(__union_.__unex_), __rhs.__union_.__unex_);
-    }
-  }
+      : __union_(__union_from_expected(__rhs)), __has_val_(__rhs.__has_val_) {}
 
   template <class _Up, class _OtherErr>
     requires __can_convert<_Up, _OtherErr, _OtherErr>::value
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<_OtherErr, _Err>)
   expected(expected<_Up, _OtherErr>&& __rhs)
     noexcept(is_nothrow_constructible_v<_Err, _OtherErr>) // strengthened
-      : __has_val_(__rhs.__has_val_) {
-    if (!__rhs.__has_val_) {
-      std::construct_at(std::addressof(__union_.__unex_), std::move(__rhs.__union_.__unex_));
-    }
-  }
+      : __union_(__union_from_expected(std::move(__rhs))), __has_val_(__rhs.__has_val_) {}
 
   template <class _OtherErr>
     requires is_constructible_v<_Err, const _OtherErr&>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<const _OtherErr&, _Err>)
   expected(const unexpected<_OtherErr>& __unex)
     noexcept(is_nothrow_constructible_v<_Err, const _OtherErr&>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), __unex.error());
-  }
+      : __union_(__construct_unexpected_tag{}, __unex.error()), __has_val_(false) {}
 
   template <class _OtherErr>
     requires is_constructible_v<_Err, _OtherErr>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit(!is_convertible_v<_OtherErr, _Err>)
   expected(unexpected<_OtherErr>&& __unex)
     noexcept(is_nothrow_constructible_v<_Err, _OtherErr>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), std::move(__unex.error()));
-  }
+      : __union_(__construct_unexpected_tag{}, std::move(__unex.error())), __has_val_(false) {}
 
   _LIBCPP_HIDE_FROM_ABI constexpr explicit expected(in_place_t) noexcept : __has_val_(true) {}
 
@@ -1063,17 +1034,13 @@ class expected<_Tp, _Err> {
     requires is_constructible_v<_Err, _Args...>
   _LIBCPP_HIDE_FROM_ABI constexpr explicit expected(unexpect_t, _Args&&... __args)
     noexcept(is_nothrow_constructible_v<_Err, _Args...>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), std::forward<_Args>(__args)...);
-  }
+      : __union_(__construct_unexpected_tag{}, std::forward<_Args>(__args)...), __has_val_(false) {}
 
   template <class _Up, class... _Args>
     requires is_constructible_v< _Err, initializer_list<_Up>&, _Args... >
   _LIBCPP_HIDE_FROM_ABI constexpr explicit expected(unexpect_t, initializer_list<_Up> __il, _Args&&... __args)
     noexcept(is_nothrow_constructible_v<_Err, initializer_list<_Up>&, _Args...>) // strengthened
-      : __has_val_(false) {
-    std::construct_at(std::addressof(__union_.__unex_), __il, std::forward<_Args>(__args)...);
-  }
+      : __union_(__construct_unexpected_tag{}, __il, std::forward<_Args>(__args)...), __has_val_(false) {}
 
 private:
   template <class _Func>
@@ -1502,11 +1469,16 @@ class expected<_Tp, _Err> {
 
 private:
   struct __empty_t {};
+  struct __construct_unexpected_tag {};
 
   template <class _ErrorType>
   union __union_t {
     _LIBCPP_HIDE_FROM_ABI constexpr __union_t() : __empty_() {}
 
+    template <class... _Args>
+    _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(__construct_unexpected_tag, _Args&&... __args)
+        : __unex_(std::forward<_Args>(__args)...) {}
+
     template <class _Func, class... _Args>
     _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(
         __expected_construct_unexpected_from_invoke_tag, _Func&& __f, _Args&&... __args)
@@ -1534,6 +1506,10 @@ class expected<_Tp, _Err> {
     _LIBCPP_HIDE_FROM_ABI constexpr __union_t(const __union_t&) = default;
     _LIBCPP_HIDE_FROM_ABI constexpr __union_t& operator=(const __union_t&) = default;
 
+    template <class... _Args>
+    _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(__construct_unexpected_tag, _Args&&... __args)
+        : __unex_(std::forward<_Args>(__args)...) {}
+
     template <class _Func, class... _Args>
     _LIBCPP_HIDE_FROM_ABI constexpr explicit __union_t(
         __expected_construct_unexpected_from_invoke_tag, _Func&& __f, _Args&&... __args)
@@ -1552,6 +1528,18 @@ class expected<_Tp, _Err> {
     _LIBCPP_NO_UNIQUE_ADDRESS _ErrorType __unex_;
   };
 
+  template <class _Up, class _OtherErr>
+  _LIBCPP_HIDE_FROM_ABI constexpr __union_t<_Err> __union_from_expected(const expected<_Up, _OtherErr>& __other) {
+    return __other.__has_val_ ? __union_t<_Err>()
+                              : __union_t<_Err>(__construct_unexpected_tag{}, __other.__union_.__unex_);
+  }
+
+  template <class _Up, class _OtherErr>
+  _LIBCPP_HIDE_FROM_ABI constexpr __union_t<_Err> __union_from_expected(expected<_Up, _OtherErr>&& __other) {
+    return __other.__has_val_ ? __union_t<_Err>()
+                              : __union_t<_Err>(__construct_unexpected_tag{}, std::move(__other.__union_.__unex_));
+  }
+
   _LIBCPP_NO_UNIQUE_ADDRESS __union_t<_Err> __union_;
   bool __has_val_;
 };
diff --git a/libcxx/test/std/utilities/expected/expected.expected/observers/has_value.pass.cpp b/libcxx/test/std/utilities/expected/expected.expected/observers/has_value.pass.cpp
index 27d657a065699ea..8979e0f45d44f50 100644
--- a/libcxx/test/std/utilities/expected/expected.expected/observers/has_value.pass.cpp
+++ b/libcxx/test/std/utilities/expected/expected.expected/observers/has_value.pass.cpp
@@ -12,6 +12,7 @@
 #include <cassert>
 #include <concepts>
 #include <expected>
+#include <optional>
 #include <type_traits>
 #include <utility>
 
@@ -30,6 +31,22 @@ static_assert(!HasValueNoexcept<Foo>);
 static_assert(HasValueNoexcept<std::expected<int, int>>);
 static_assert(HasValueNoexcept<const std::expected<int, int>>);
 
+// This type has one byte of tail padding where `std::expected` will put its
+// "has value" flag. The constructor will clobber all bytes including the
+// tail padding. With this type we can check that `std::expected` will set
+// its "has value" flag _after_ the value/error object is constructed.
+template <int c>
+struct tail_clobberer {
+  constexpr tail_clobberer() {
+    if (!std::is_constant_evaluated()) {
+      // This `memset` might actually be UB (?) but suffices to reproduce bugs
+      // related to the "has value" flag.
+      std::memset(this, c, sizeof(*this));
+    }
+  }
+  alignas(2) bool b;
+};
+
 constexpr bool test() {
   // has_value
   {
@@ -43,6 +60,28 @@ constexpr bool test() {
     assert(!e.has_value());
   }
 
+  // See https://github.com/llvm/llvm-project/issues/68552
+  {
+    static constexpr auto f1 = [] -> std::expected<std::optional<int>, long> { return 0; };
+
+    static constexpr auto f2 = [] -> std::expected<std::optional<int>, int> {
+      return f1().transform_error([](auto) { return 0; });
+    };
+
+    auto e = f2();
+    assert(e.has_value());
+  }
+  {
+    const std::expected<tail_clobberer<0>, bool> e = {};
+    static_assert(sizeof(tail_clobberer<0>) == sizeof(e));
+    assert(e.has_value());
+  }
+  {
+    const std::expected<void, tail_clobberer<1>> e(std::unexpect);
+    static_assert(sizeof(tail_clobberer<1>) == sizeof(e));
+    assert(!e.has_value());
+  }
+
   return true;
 }
 



More information about the libcxx-commits mailing list