[llvm-branch-commits] [flang][runtime] Added Fortran::common::reference_wrapper for use on device. (PR #85178)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Mar 13 22:32:28 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

This is a simplified implementation of std::reference_wrapper that can be used
in the offload builds for the device code. The methods are properly
marked with RT_API_ATTRS so that the device compilation succeeds.


---
Full diff: https://github.com/llvm/llvm-project/pull/85178.diff


2 Files Affected:

- (added) flang/include/flang/Common/reference-wrapper.h (+114) 
- (modified) flang/runtime/io-stmt.h (+34-25) 


``````````diff
diff --git a/flang/include/flang/Common/reference-wrapper.h b/flang/include/flang/Common/reference-wrapper.h
new file mode 100644
index 00000000000000..66f924662d9612
--- /dev/null
+++ b/flang/include/flang/Common/reference-wrapper.h
@@ -0,0 +1,114 @@
+//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// clang-format off
+//
+// Implementation of std::reference_wrapper borrowed from libcu++
+// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
+// with modifications.
+//
+// The original source code is distributed under the Apache License v2.0
+// with LLVM Exceptions.
+//
+// TODO: using libcu++ is the best option for CUDA, but there is a couple
+// of issues:
+//   * The include paths need to be set up such that all STD header files
+//     are taken from libcu++.
+//   * cuda:: namespace need to be forced for all std:: references.
+//
+// clang-format on
+
+#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
+#define FORTRAN_COMMON_REFERENCE_WRAPPER_H
+
+#include "flang/Runtime/api-attrs.h"
+#include <functional>
+#include <type_traits>
+
+#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
+    (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
+#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
+#endif
+
+namespace Fortran::common {
+
+template <class _Tp>
+using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
+template <class _Tp, class _Up>
+struct __is_same_uncvref
+    : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};
+
+#if STD_REFERENCE_WRAPPER_UNSUPPORTED
+template <class _Tp> class reference_wrapper {
+public:
+  // types
+  typedef _Tp type;
+
+private:
+  type *__f_;
+
+  static RT_API_ATTRS void __fun(_Tp &);
+  static void __fun(_Tp &&) = delete;
+
+public:
+  template <class _Up,
+      class =
+          std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
+              decltype(__fun(std::declval<_Up>()))>>
+  constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
+    type &__f = static_cast<_Up &&>(__u);
+    __f_ = std::addressof(__f);
+  }
+
+  // access
+  constexpr RT_API_ATTRS operator type &() const { return *__f_; }
+  constexpr RT_API_ATTRS type &get() const { return *__f_; }
+
+  // invoke
+  template <class... _ArgTypes>
+  constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
+  operator()(_ArgTypes &&...__args) const {
+    return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
+  }
+};
+
+template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
+  return reference_wrapper<_Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
+    reference_wrapper<_Tp> __t) {
+  return __t;
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+    const _Tp &__t) {
+  return reference_wrapper<const _Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+    reference_wrapper<_Tp> __t) {
+  return __t;
+}
+
+template <class _Tp> void ref(const _Tp &&) = delete;
+template <class _Tp> void cref(const _Tp &&) = delete;
+#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+using std::cref;
+using std::ref;
+using std::reference_wrapper;
+#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+
+} // namespace Fortran::common
+
+#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index 0477c32b3b53ad..e00d54980aae59 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -17,6 +17,7 @@
 #include "internal-unit.h"
 #include "io-error.h"
 #include "flang/Common/optional.h"
+#include "flang/Common/reference-wrapper.h"
 #include "flang/Common/visit.h"
 #include "flang/Runtime/descriptor.h"
 #include "flang/Runtime/io-api.h"
@@ -210,39 +211,47 @@ class IoStatementState {
   }
 
 private:
-  std::variant<std::reference_wrapper<OpenStatementState>,
-      std::reference_wrapper<CloseStatementState>,
-      std::reference_wrapper<NoopStatementState>,
-      std::reference_wrapper<
+  std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
+      Fortran::common::reference_wrapper<CloseStatementState>,
+      Fortran::common::reference_wrapper<NoopStatementState>,
+      Fortran::common::reference_wrapper<
           InternalFormattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           InternalFormattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
-      std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
+          InternalListIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          InternalListIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
           ExternalFormattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           ExternalFormattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
-      std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
+          ExternalListIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          ExternalListIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
           ExternalUnformattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           ExternalUnformattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
-      std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
+          ChildFormattedIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          ChildFormattedIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
+          ChildListIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          ChildListIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
           ChildUnformattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           ChildUnformattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<InquireUnitState>,
-      std::reference_wrapper<InquireNoUnitState>,
-      std::reference_wrapper<InquireUnconnectedFileState>,
-      std::reference_wrapper<InquireIOLengthState>,
-      std::reference_wrapper<ExternalMiscIoStatementState>,
-      std::reference_wrapper<ErroneousIoStatementState>>
+      Fortran::common::reference_wrapper<InquireUnitState>,
+      Fortran::common::reference_wrapper<InquireNoUnitState>,
+      Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
+      Fortran::common::reference_wrapper<InquireIOLengthState>,
+      Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
+      Fortran::common::reference_wrapper<ErroneousIoStatementState>>
       u_;
 };
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/85178


More information about the llvm-branch-commits mailing list