[llvm-branch-commits] [flang][runtime] Added custom visitor for IoStatementState variants. (PR #85179)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Mar 13 22:32:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

During device compilation, the visitor only allows Internal.*IoStatementState
variants to be visited (currently the internal list-directed and formatted
output variants); if another variant is encountered at run time, a runtime
error is produced. As a result, the other variants' classes are not referenced
in device code, which, for example, avoids warnings about __host__-only
methods being referenced from __device__ code.

I had trouble parameterizing the Fortran::common visitor to restrict
the allowed variants, so this patch adds a local copy of it instead; I can
give parameterization another try if duplicating the visitor looks inappropriate.


---
Full diff: https://github.com/llvm/llvm-project/pull/85179.diff


2 Files Affected:

- (modified) flang/runtime/io-stmt.cpp (+20-27) 
- (modified) flang/runtime/io-stmt.h (+56-2) 


``````````diff
diff --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 075d7b5ae518a4..efefbc5e1a1c08 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -467,69 +467,66 @@ int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
 }
 
 Fortran::common::optional<DataEdit> IoStatementState::GetNextDataEdit(int n) {
-  return common::visit(
-      [&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
+  return visit([&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
 }
 
 bool IoStatementState::Emit(
     const char *data, std::size_t bytes, std::size_t elementBytes) {
-  return common::visit(
+  return visit(
       [=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_);
 }
 
 bool IoStatementState::Receive(
     char *data, std::size_t n, std::size_t elementBytes) {
-  return common::visit(
+  return visit(
       [=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_);
 }
 
 std::size_t IoStatementState::GetNextInputBytes(const char *&p) {
-  return common::visit(
-      [&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
+  return visit([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
 }
 
 bool IoStatementState::AdvanceRecord(int n) {
-  return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
+  return visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
 }
 
 void IoStatementState::BackspaceRecord() {
-  common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
+  visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
 }
 
 void IoStatementState::HandleRelativePosition(std::int64_t n) {
-  common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
+  visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
 }
 
 void IoStatementState::HandleAbsolutePosition(std::int64_t n) {
-  common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
+  visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
 }
 
 void IoStatementState::CompleteOperation() {
-  common::visit([](auto &x) { x.get().CompleteOperation(); }, u_);
+  visit([](auto &x) { x.get().CompleteOperation(); }, u_);
 }
 
 int IoStatementState::EndIoStatement() {
-  return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
+  return visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
 }
 
 ConnectionState &IoStatementState::GetConnectionState() {
-  return common::visit(
+  return visit(
       [](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); },
       u_);
 }
 
 MutableModes &IoStatementState::mutableModes() {
-  return common::visit(
+  return visit(
       [](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_);
 }
 
 bool IoStatementState::BeginReadingRecord() {
-  return common::visit(
-      [](auto &x) { return x.get().BeginReadingRecord(); }, u_);
+  return visit([](auto &x) { return x.get().BeginReadingRecord(); }, u_);
 }
 
 IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
-  return common::visit(
+  return visit(
       [](auto &x) -> IoErrorHandler & {
         return static_cast<IoErrorHandler &>(x.get());
       },
@@ -537,8 +534,7 @@ IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
 }
 
 ExternalFileUnit *IoStatementState::GetExternalFileUnit() const {
-  return common::visit(
-      [](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
+  return visit([](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
 }
 
 Fortran::common::optional<char32_t> IoStatementState::GetCurrentChar(
@@ -664,28 +660,25 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) {
 
 bool IoStatementState::Inquire(
     InquiryKeywordHash inquiry, char *out, std::size_t chars) {
-  return common::visit(
+  return visit(
       [&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_);
 }
 
 bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) {
-  return common::visit(
-      [&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
+  return visit([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
 }
 
 bool IoStatementState::Inquire(
     InquiryKeywordHash inquiry, std::int64_t id, bool &out) {
-  return common::visit(
-      [&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
+  return visit([&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
 }
 
 bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) {
-  return common::visit(
-      [&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
+  return visit([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
 }
 
 std::int64_t IoStatementState::InquirePos() {
-  return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_);
+  return visit([&](auto &x) { return x.get().InquirePos(); }, u_);
 }
 
 void IoStatementState::GotChar(int n) {
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index e00d54980aae59..7fecf4d9e41754 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -18,7 +18,6 @@
 #include "io-error.h"
 #include "flang/Common/optional.h"
 #include "flang/Common/reference-wrapper.h"
-#include "flang/Common/visit.h"
 #include "flang/Runtime/descriptor.h"
 #include "flang/Runtime/io-api.h"
 #include <functional>
@@ -113,7 +112,7 @@ class IoStatementState {
 
   // N.B.: this also works with base classes
   template <typename A> A *get_if() const {
-    return common::visit(
+    return visit(
         [](auto &x) -> A * {
           if constexpr (std::is_convertible_v<decltype(x.get()), A &>) {
             return &x.get();
@@ -211,6 +210,61 @@ class IoStatementState {
   }
 
 private:
+  // Define special visitor for the variants of IoStatementState.
+  // During the device code compilation the visitor only allows
+  // visiting those variants that are supported on the device.
+  // In particular, only the internal IO variants are supported.
+  // TODO: parameterize Fortran::common::log2visit instead of
+  //       creating a copy here.
+  template <class T, class... Ts>
+  struct is_any_type : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
+
+  template <std::size_t LOW, std::size_t HIGH, typename RESULT,
+      typename VISITOR, typename VARIANT>
+  static inline RT_API_ATTRS RESULT Log2VisitHelper(
+      VISITOR &&visitor, std::size_t which, VARIANT &&u) {
+#if !defined(RT_DEVICE_COMPILATION)
+    constexpr bool isDevice{false};
+#else
+    constexpr bool isDevice{true};
+#endif
+    if constexpr (LOW == HIGH) {
+      if constexpr (!isDevice ||
+          is_any_type<
+              std::variant_alternative_t<LOW, std::decay_t<decltype(u)>>,
+              Fortran::common::reference_wrapper<
+                  InternalListIoStatementState<Direction::Output>>,
+              Fortran::common::reference_wrapper<
+                  InternalFormattedIoStatementState<Direction::Output>>>::
+              value) {
+        return visitor(std::get<LOW>(std::forward<VARIANT>(u)));
+      } else {
+        Terminator{__FILE__, __LINE__}.Crash(
+            "not implemented yet: IoStatementState variant %d\n",
+            static_cast<int>(LOW));
+      }
+    } else {
+      static constexpr std::size_t mid{(HIGH + LOW) / 2};
+      if (which <= mid) {
+        return Log2VisitHelper<LOW, mid, RESULT>(
+            std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+      } else {
+        return Log2VisitHelper<(mid + 1), HIGH, RESULT>(
+            std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+      }
+    }
+  }
+
+  template <typename VISITOR, typename VARIANT>
+  static inline RT_API_ATTRS auto visit(VISITOR &&visitor, VARIANT &&u)
+      -> decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))) {
+    using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u))));
+    static constexpr std::size_t high{
+        std::variant_size_v<std::decay_t<decltype(u)>> - 1};
+    return Log2VisitHelper<0, high, Result>(
+        std::forward<VISITOR>(visitor), u.index(), std::forward<VARIANT>(u));
+  }
+
   std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
       Fortran::common::reference_wrapper<CloseStatementState>,
       Fortran::common::reference_wrapper<NoopStatementState>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/85179


More information about the llvm-branch-commits mailing list