[llvm-branch-commits] [flang][runtime] Added Fortran::common::reference_wrapper for use on device. (PR #85178)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 13 22:32:28 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
This is a simplified implementation of std::reference_wrapper that can be used
in the offload builds for the device code. The methods are properly
marked with RT_API_ATTRS so that the device compilation succeeds.
---
Full diff: https://github.com/llvm/llvm-project/pull/85178.diff
2 Files Affected:
- (added) flang/include/flang/Common/reference-wrapper.h (+114)
- (modified) flang/runtime/io-stmt.h (+34-25)
``````````diff
diff --git a/flang/include/flang/Common/reference-wrapper.h b/flang/include/flang/Common/reference-wrapper.h
new file mode 100644
index 00000000000000..66f924662d9612
--- /dev/null
+++ b/flang/include/flang/Common/reference-wrapper.h
@@ -0,0 +1,114 @@
+//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// clang-format off
+//
+// Implementation of std::reference_wrapper borrowed from libcu++
+// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
+// with modifications.
+//
+// The original source code is distributed under the Apache License v2.0
+// with LLVM Exceptions.
+//
+// TODO: using libcu++ is the best option for CUDA, but there are a couple
+// of issues:
+// * The include paths need to be set up such that all STD header files
+// are taken from libcu++.
+// * The cuda:: namespace needs to be forced for all std:: references.
+//
+// clang-format on
+
+#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
+#define FORTRAN_COMMON_REFERENCE_WRAPPER_H
+
+#include "flang/Runtime/api-attrs.h"
+#include <functional>
+#include <type_traits>
+
+#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
+ (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
+#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
+#endif
+
+namespace Fortran::common {
+
+template <class _Tp>
+using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
+template <class _Tp, class _Up>
+struct __is_same_uncvref
+ : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};
+
+#if STD_REFERENCE_WRAPPER_UNSUPPORTED
+template <class _Tp> class reference_wrapper {
+public:
+ // types
+ typedef _Tp type;
+
+private:
+ type *__f_;
+
+ static RT_API_ATTRS void __fun(_Tp &);
+ static void __fun(_Tp &&) = delete;
+
+public:
+ template <class _Up,
+ class =
+ std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
+ decltype(__fun(std::declval<_Up>()))>>
+ constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
+ type &__f = static_cast<_Up &&>(__u);
+ __f_ = std::addressof(__f);
+ }
+
+ // access
+ constexpr RT_API_ATTRS operator type &() const { return *__f_; }
+ constexpr RT_API_ATTRS type &get() const { return *__f_; }
+
+ // invoke
+ template <class... _ArgTypes>
+ constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
+ operator()(_ArgTypes &&...__args) const {
+ return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
+ }
+};
+
+template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
+ return reference_wrapper<_Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
+ reference_wrapper<_Tp> __t) {
+ return __t;
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+ const _Tp &__t) {
+ return reference_wrapper<const _Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+ reference_wrapper<_Tp> __t) {
+ return __t;
+}
+
+template <class _Tp> void ref(const _Tp &&) = delete;
+template <class _Tp> void cref(const _Tp &&) = delete;
+#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+using std::cref;
+using std::ref;
+using std::reference_wrapper;
+#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+
+} // namespace Fortran::common
+
+#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index 0477c32b3b53ad..e00d54980aae59 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -17,6 +17,7 @@
#include "internal-unit.h"
#include "io-error.h"
#include "flang/Common/optional.h"
+#include "flang/Common/reference-wrapper.h"
#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
@@ -210,39 +211,47 @@ class IoStatementState {
}
private:
- std::variant<std::reference_wrapper<OpenStatementState>,
- std::reference_wrapper<CloseStatementState>,
- std::reference_wrapper<NoopStatementState>,
- std::reference_wrapper<
+ std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
+ Fortran::common::reference_wrapper<CloseStatementState>,
+ Fortran::common::reference_wrapper<NoopStatementState>,
+ Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
- std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
+ InternalListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ InternalListIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
- std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
+ ExternalListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ ExternalListIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
- std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
+ ChildFormattedIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ ChildFormattedIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
+ ChildListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ ChildListIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<InquireUnitState>,
- std::reference_wrapper<InquireNoUnitState>,
- std::reference_wrapper<InquireUnconnectedFileState>,
- std::reference_wrapper<InquireIOLengthState>,
- std::reference_wrapper<ExternalMiscIoStatementState>,
- std::reference_wrapper<ErroneousIoStatementState>>
+ Fortran::common::reference_wrapper<InquireUnitState>,
+ Fortran::common::reference_wrapper<InquireNoUnitState>,
+ Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
+ Fortran::common::reference_wrapper<InquireIOLengthState>,
+ Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
+ Fortran::common::reference_wrapper<ErroneousIoStatementState>>
u_;
};
``````````
</details>
https://github.com/llvm/llvm-project/pull/85178
More information about the llvm-branch-commits
mailing list