[llvm] 01bf8cd - [ADT] Support const-qualified unique_functions
Sam McCall via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 29 11:13:55 PDT 2020
Author: Sam McCall
Date: 2020-06-29T20:13:42+02:00
New Revision: 01bf8cdf5fa9bc71869e15e5e351b2b68c39feb6
URL: https://github.com/llvm/llvm-project/commit/01bf8cdf5fa9bc71869e15e5e351b2b68c39feb6
DIFF: https://github.com/llvm/llvm-project/commit/01bf8cdf5fa9bc71869e15e5e351b2b68c39feb6.diff
LOG: [ADT] Support const-qualified unique_functions
Summary:
This technique should extend to rvalue-qualified, etc., but I didn't add any.
I removed "volatile" from the future plans, which seems... speculative at best.
While here I moved the callbacks object out of the constructor into a
variable template, which I believe addresses the fixme there about unused
objects.
(I'm not a template guru, so it's always possible the old version was designed
for compile-time performance in a way I'm missing.)
Reviewers: kadircet
Subscribers: dexonsmith, llvm-commits, chandlerc
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D82581
Added:
Modified:
llvm/include/llvm/ADT/FunctionExtras.h
llvm/unittests/ADT/FunctionExtrasTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/FunctionExtras.h b/llvm/include/llvm/ADT/FunctionExtras.h
index ad84bbc35b78..8d4c3fe830c3 100644
--- a/llvm/include/llvm/ADT/FunctionExtras.h
+++ b/llvm/include/llvm/ADT/FunctionExtras.h
@@ -11,11 +11,11 @@
/// in `<function>`.
///
/// It provides `unique_function`, which works like `std::function` but supports
-/// move-only callable objects.
+/// move-only callable objects and const-qualification.
///
/// Future plans:
-/// - Add a `function` that provides const, volatile, and ref-qualified support,
-/// which doesn't work with `std::function`.
+/// - Add a `function` that provides ref-qualified support, which doesn't work
+/// with `std::function`.
/// - Provide support for specifying multiple signatures to type erase callable
/// objects with an overload set, such as those produced by generic lambdas.
/// - Expand to include a copyable utility that directly replaces std::function
@@ -37,13 +37,31 @@
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/type_traits.h"
#include <memory>
+#include <type_traits>
namespace llvm {
+/// unique_function is a type-erasing functor similar to std::function.
+///
+/// It can hold move-only function objects, like lambdas capturing unique_ptrs.
+/// Accordingly, it is movable but not copyable.
+///
+/// It supports const-qualification:
+/// - unique_function<int() const> has a const operator().
+/// It can only hold functions which themselves have a const operator().
+/// - unique_function<int()> has a non-const operator().
+/// It can hold functions with a non-const operator(), like mutable lambdas.
template <typename FunctionT> class unique_function;
-template <typename ReturnT, typename... ParamTs>
-class unique_function<ReturnT(ParamTs...)> {
+namespace detail {
+
+template <typename T>
+using EnableIfTrivial =
+ std::enable_if_t<llvm::is_trivially_move_constructible<T>::value &&
+ std::is_trivially_destructible<T>::value>;
+
+template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase {
+protected:
static constexpr size_t InlineStorageSize = sizeof(void *) * 3;
// MSVC has a bug and ICEs if we give it a particular dependent value
@@ -113,8 +131,11 @@ class unique_function<ReturnT(ParamTs...)> {
// For in-line storage, we just provide an aligned character buffer. We
// provide three pointers worth of storage here.
- typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type
- InlineStorage;
+ // This is mutable as an inlined `const unique_function<void() const>` may
+ // still modify its own mutable members.
+ mutable
+ typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type
+ InlineStorage;
} StorageUnion;
// A compressed pointer to either our dispatching callback or our table of
@@ -137,11 +158,25 @@ class unique_function<ReturnT(ParamTs...)> {
.template get<NonTrivialCallbacks *>();
}
- void *getInlineStorage() { return &StorageUnion.InlineStorage; }
+ CallPtrT getCallPtr() const {
+ return isTrivialCallback() ? getTrivialCallback()
+ : getNonTrivialCallbacks()->CallPtr;
+ }
- void *getOutOfLineStorage() {
+ // These three functions are only const in the narrow sense. They return
+ // mutable pointers to function state.
+ // This allows unique_function<T const>::operator() to be const, even if the
+ // underlying functor may be internally mutable.
+ //
+ // const callers must ensure they're only used in const-correct ways.
+ void *getCalleePtr() const {
+ return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
+ }
+ void *getInlineStorage() const { return &StorageUnion.InlineStorage; }
+ void *getOutOfLineStorage() const {
return StorageUnion.OutOfLineStorage.StoragePtr;
}
+
size_t getOutOfLineStorageSize() const {
return StorageUnion.OutOfLineStorage.Size;
}
@@ -153,10 +188,11 @@ class unique_function<ReturnT(ParamTs...)> {
StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment};
}
- template <typename CallableT>
- static ReturnT CallImpl(void *CallableAddr, AdjustedParamT<ParamTs>... Params) {
- return (*reinterpret_cast<CallableT *>(CallableAddr))(
- std::forward<ParamTs>(Params)...);
+ template <typename CalledAsT>
+ static ReturnT CallImpl(void *CallableAddr,
+ AdjustedParamT<ParamTs>... Params) {
+ auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr);
+ return Func(std::forward<ParamTs>(Params)...);
}
template <typename CallableT>
@@ -170,11 +206,49 @@ class unique_function<ReturnT(ParamTs...)> {
reinterpret_cast<CallableT *>(CallableAddr)->~CallableT();
}
-public:
- unique_function() = default;
- unique_function(std::nullptr_t /*null_callable*/) {}
+ // The pointers to call/move/destroy functions are determined for each
+ // callable type (and called-as type, which determines the overload chosen).
+ // (definitions are out-of-line).
+
+ // By default, we need an object that contains all the different
+ // type erased behaviors needed. Create a static instance of the struct type
+ // here and each instance will contain a pointer to it.
+ template <typename CallableT, typename CalledAs, typename Enable = void>
+ static NonTrivialCallbacks Callbacks;
+ // See if we can create a trivial callback. We need the callable to be
+ // trivially moved and trivially destroyed so that we don't have to store
+ // type erased callbacks for those operations.
+ template <typename CallableT, typename CalledAs>
+ static TrivialCallback
+ Callbacks<CallableT, CalledAs, EnableIfTrivial<CallableT>>;
+
+ // A simple tag type so the call-as type to be passed to the constructor.
+ template <typename T> struct CalledAs {};
+
+ // Essentially the "main" unique_function constructor, but subclasses
+ // provide the qualified type to be used for the call.
+ // (We always store a T, even if the call will use a pointer to const T).
+ template <typename CallableT, typename CalledAsT>
+ UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) {
+ bool IsInlineStorage = true;
+ void *CallableAddr = getInlineStorage();
+ if (sizeof(CallableT) > InlineStorageSize ||
+ alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
+ IsInlineStorage = false;
+ // Allocate out-of-line storage. FIXME: Use an explicit alignment
+ // parameter in C++17 mode.
+ auto Size = sizeof(CallableT);
+ auto Alignment = alignof(CallableT);
+ CallableAddr = allocate_buffer(Size, Alignment);
+ setOutOfLineStorage(CallableAddr, Size, Alignment);
+ }
+
+ // Now move into the storage.
+ new (CallableAddr) CallableT(std::move(Callable));
+ CallbackAndInlineFlag = {&Callbacks<CallableT, CalledAsT>, IsInlineStorage};
+ }
- ~unique_function() {
+ ~UniqueFunctionBase() {
if (!CallbackAndInlineFlag.getPointer())
return;
@@ -190,7 +264,7 @@ class unique_function<ReturnT(ParamTs...)> {
getOutOfLineStorageAlignment());
}
- unique_function(unique_function &&RHS) noexcept {
+ UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept {
// Copy the callback and inline flag.
CallbackAndInlineFlag = RHS.CallbackAndInlineFlag;
@@ -219,72 +293,82 @@ class unique_function<ReturnT(ParamTs...)> {
#endif
}
- unique_function &operator=(unique_function &&RHS) noexcept {
+ UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept {
if (this == &RHS)
return *this;
// Because we don't try to provide any exception safety guarantees we can
// implement move assignment very simply by first destroying the current
// object and then move-constructing over top of it.
- this->~unique_function();
- new (this) unique_function(std::move(RHS));
+ this->~UniqueFunctionBase();
+ new (this) UniqueFunctionBase(std::move(RHS));
return *this;
}
- template <typename CallableT> unique_function(CallableT Callable) {
- bool IsInlineStorage = true;
- void *CallableAddr = getInlineStorage();
- if (sizeof(CallableT) > InlineStorageSize ||
- alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
- IsInlineStorage = false;
- // Allocate out-of-line storage. FIXME: Use an explicit alignment
- // parameter in C++17 mode.
- auto Size = sizeof(CallableT);
- auto Alignment = alignof(CallableT);
- CallableAddr = allocate_buffer(Size, Alignment);
- setOutOfLineStorage(CallableAddr, Size, Alignment);
- }
+ UniqueFunctionBase() = default;
- // Now move into the storage.
- new (CallableAddr) CallableT(std::move(Callable));
+public:
+ explicit operator bool() const {
+ return (bool)CallbackAndInlineFlag.getPointer();
+ }
+};
- // See if we can create a trivial callback. We need the callable to be
- // trivially moved and trivially destroyed so that we don't have to store
- // type erased callbacks for those operations.
- //
- // FIXME: We should use constexpr if here and below to avoid instantiating
- // the non-trivial static objects when unnecessary. While the linker should
- // remove them, it is still wasteful.
- if (llvm::is_trivially_move_constructible<CallableT>::value &&
- std::is_trivially_destructible<CallableT>::value) {
- // We need to create a nicely aligned object. We use a static variable
- // for this because it is a trivial struct.
- static TrivialCallback Callback = { &CallImpl<CallableT> };
-
- CallbackAndInlineFlag = {&Callback, IsInlineStorage};
- return;
- }
+template <typename R, typename... P>
+template <typename CallableT, typename CalledAsT, typename Enable>
+typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks
+ UniqueFunctionBase<R, P...>::Callbacks = {
+ &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};
- // Otherwise, we need to point at an object that contains all the different
- // type erased behaviors needed. Create a static instance of the struct type
- // here and then use a pointer to that.
- static NonTrivialCallbacks Callbacks = {
- &CallImpl<CallableT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};
+template <typename R, typename... P>
+template <typename CallableT, typename CalledAsT>
+typename UniqueFunctionBase<R, P...>::TrivialCallback UniqueFunctionBase<
+ R, P...>::Callbacks<CallableT, CalledAsT, EnableIfTrivial<CallableT>>{
+ &CallImpl<CalledAsT>};
- CallbackAndInlineFlag = {&Callbacks, IsInlineStorage};
- }
+} // namespace detail
+
+template <typename R, typename... P>
+class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> {
+ using Base = detail::UniqueFunctionBase<R, P...>;
+
+public:
+ unique_function() = default;
+ unique_function(std::nullptr_t) {}
+ unique_function(unique_function &&) = default;
+ unique_function(const unique_function &) = delete;
+ unique_function &operator=(unique_function &&) = default;
+ unique_function &operator=(const unique_function &) = delete;
- ReturnT operator()(ParamTs... Params) {
- void *CallableAddr =
- isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
+ template <typename CallableT>
+ unique_function(CallableT Callable)
+ : Base(std::forward<CallableT>(Callable),
+ typename Base::template CalledAs<CallableT>{}) {}
- return (isTrivialCallback()
- ? getTrivialCallback()
- : getNonTrivialCallbacks()->CallPtr)(CallableAddr, Params...);
+ R operator()(P... Params) {
+ return this->getCallPtr()(this->getCalleePtr(), Params...);
}
+};
- explicit operator bool() const {
- return (bool)CallbackAndInlineFlag.getPointer();
+template <typename R, typename... P>
+class unique_function<R(P...) const>
+ : public detail::UniqueFunctionBase<R, P...> {
+ using Base = detail::UniqueFunctionBase<R, P...>;
+
+public:
+ unique_function() = default;
+ unique_function(std::nullptr_t) {}
+ unique_function(unique_function &&) = default;
+ unique_function(const unique_function &) = delete;
+ unique_function &operator=(unique_function &&) = default;
+ unique_function &operator=(const unique_function &) = delete;
+
+ template <typename CallableT>
+ unique_function(CallableT Callable)
+ : Base(std::forward<CallableT>(Callable),
+ typename Base::template CalledAs<const CallableT>{}) {}
+
+ R operator()(P... Params) const {
+ return this->getCallPtr()(this->getCalleePtr(), Params...);
}
};
diff --git a/llvm/unittests/ADT/FunctionExtrasTest.cpp b/llvm/unittests/ADT/FunctionExtrasTest.cpp
index bbbb045cb14a..2ae0d1813858 100644
--- a/llvm/unittests/ADT/FunctionExtrasTest.cpp
+++ b/llvm/unittests/ADT/FunctionExtrasTest.cpp
@@ -10,6 +10,7 @@
#include "gtest/gtest.h"
#include <memory>
+#include <type_traits>
using namespace llvm;
@@ -224,4 +225,41 @@ TEST(UniqueFunctionTest, CountForwardingMoves) {
UnmovableF(X);
}
+TEST(UniqueFunctionTest, Const) {
+ // Can assign from const lambda.
+ unique_function<int(int) const> Plus2 = [X(std::make_unique<int>(2))](int Y) {
+ return *X + Y;
+ };
+ EXPECT_EQ(5, Plus2(3));
+
+ // Can call through a const ref.
+ const auto &Plus2Ref = Plus2;
+ EXPECT_EQ(5, Plus2Ref(3));
+
+ // Can move-construct and assign.
+ unique_function<int(int) const> Plus2A = std::move(Plus2);
+ EXPECT_EQ(5, Plus2A(3));
+ unique_function<int(int) const> Plus2B;
+ Plus2B = std::move(Plus2A);
+ EXPECT_EQ(5, Plus2B(3));
+
+ // Can convert to non-const function type, but not back.
+ unique_function<int(int)> Plus2C = std::move(Plus2B);
+ EXPECT_EQ(5, Plus2C(3));
+
+ // Overloaded call operator correctly resolved.
+ struct ChooseCorrectOverload {
+ StringRef operator()() { return "non-const"; }
+ StringRef operator()() const { return "const"; }
+ };
+ unique_function<StringRef()> ChooseMutable = ChooseCorrectOverload();
+ ChooseCorrectOverload A;
+ EXPECT_EQ("non-const", ChooseMutable());
+ EXPECT_EQ("non-const", A());
+ unique_function<StringRef() const> ChooseConst = ChooseCorrectOverload();
+ const ChooseCorrectOverload &X = A;
+ EXPECT_EQ("const", ChooseConst());
+ EXPECT_EQ("const", X());
+}
+
} // anonymous namespace
More information about the llvm-commits mailing list