[llvm-branch-commits] [flang][runtime] Added custom visitor for IoStatementState variants. (PR #85179)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 13 22:32:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
During device compilation, the visitor only allows the Internal.*IoStatementState output variants to be visited.
If any other variant is encountered, a runtime error is produced.
During device compilation, the other variants' classes are not referenced,
which, for example, helps avoid warnings about __host__-only
methods being referenced in __device__ code.
I had problems parameterizing the Fortran::common visitor to limit
the allowed variants, but I can give it another try if creating
a copy looks inappropriate.
---
Full diff: https://github.com/llvm/llvm-project/pull/85179.diff
2 Files Affected:
- (modified) flang/runtime/io-stmt.cpp (+20-27)
- (modified) flang/runtime/io-stmt.h (+56-2)
``````````diff
diff --git a/flang/runtime/io-stmt.cpp b/flang/runtime/io-stmt.cpp
index 075d7b5ae518a4..efefbc5e1a1c08 100644
--- a/flang/runtime/io-stmt.cpp
+++ b/flang/runtime/io-stmt.cpp
@@ -467,69 +467,66 @@ int ExternalFormattedIoStatementState<DIR, CHAR>::EndIoStatement() {
}
Fortran::common::optional<DataEdit> IoStatementState::GetNextDataEdit(int n) {
- return common::visit(
- [&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
+ return visit([&](auto &x) { return x.get().GetNextDataEdit(*this, n); }, u_);
}
bool IoStatementState::Emit(
const char *data, std::size_t bytes, std::size_t elementBytes) {
- return common::visit(
+ return visit(
[=](auto &x) { return x.get().Emit(data, bytes, elementBytes); }, u_);
}
bool IoStatementState::Receive(
char *data, std::size_t n, std::size_t elementBytes) {
- return common::visit(
+ return visit(
[=](auto &x) { return x.get().Receive(data, n, elementBytes); }, u_);
}
std::size_t IoStatementState::GetNextInputBytes(const char *&p) {
- return common::visit(
- [&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
+ return visit([&](auto &x) { return x.get().GetNextInputBytes(p); }, u_);
}
bool IoStatementState::AdvanceRecord(int n) {
- return common::visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
+ return visit([=](auto &x) { return x.get().AdvanceRecord(n); }, u_);
}
void IoStatementState::BackspaceRecord() {
- common::visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
+ visit([](auto &x) { x.get().BackspaceRecord(); }, u_);
}
void IoStatementState::HandleRelativePosition(std::int64_t n) {
- common::visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
+ visit([=](auto &x) { x.get().HandleRelativePosition(n); }, u_);
}
void IoStatementState::HandleAbsolutePosition(std::int64_t n) {
- common::visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
+ visit([=](auto &x) { x.get().HandleAbsolutePosition(n); }, u_);
}
void IoStatementState::CompleteOperation() {
- common::visit([](auto &x) { x.get().CompleteOperation(); }, u_);
+ visit([](auto &x) { x.get().CompleteOperation(); }, u_);
}
int IoStatementState::EndIoStatement() {
- return common::visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
+ return visit([](auto &x) { return x.get().EndIoStatement(); }, u_);
}
ConnectionState &IoStatementState::GetConnectionState() {
- return common::visit(
+ return visit(
[](auto &x) -> ConnectionState & { return x.get().GetConnectionState(); },
u_);
}
MutableModes &IoStatementState::mutableModes() {
- return common::visit(
+ return visit(
[](auto &x) -> MutableModes & { return x.get().mutableModes(); }, u_);
}
bool IoStatementState::BeginReadingRecord() {
- return common::visit(
- [](auto &x) { return x.get().BeginReadingRecord(); }, u_);
+ return visit([](auto &x) { return x.get().BeginReadingRecord(); }, u_);
}
IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
- return common::visit(
+ return visit(
[](auto &x) -> IoErrorHandler & {
return static_cast<IoErrorHandler &>(x.get());
},
@@ -537,8 +534,7 @@ IoErrorHandler &IoStatementState::GetIoErrorHandler() const {
}
ExternalFileUnit *IoStatementState::GetExternalFileUnit() const {
- return common::visit(
- [](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
+ return visit([](auto &x) { return x.get().GetExternalFileUnit(); }, u_);
}
Fortran::common::optional<char32_t> IoStatementState::GetCurrentChar(
@@ -664,28 +660,25 @@ bool IoStatementState::CheckForEndOfRecord(std::size_t afterReading) {
bool IoStatementState::Inquire(
InquiryKeywordHash inquiry, char *out, std::size_t chars) {
- return common::visit(
+ return visit(
[&](auto &x) { return x.get().Inquire(inquiry, out, chars); }, u_);
}
bool IoStatementState::Inquire(InquiryKeywordHash inquiry, bool &out) {
- return common::visit(
- [&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
+ return visit([&](auto &x) { return x.get().Inquire(inquiry, out); }, u_);
}
bool IoStatementState::Inquire(
InquiryKeywordHash inquiry, std::int64_t id, bool &out) {
- return common::visit(
- [&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
+ return visit([&](auto &x) { return x.get().Inquire(inquiry, id, out); }, u_);
}
bool IoStatementState::Inquire(InquiryKeywordHash inquiry, std::int64_t &n) {
- return common::visit(
- [&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
+ return visit([&](auto &x) { return x.get().Inquire(inquiry, n); }, u_);
}
std::int64_t IoStatementState::InquirePos() {
- return common::visit([&](auto &x) { return x.get().InquirePos(); }, u_);
+ return visit([&](auto &x) { return x.get().InquirePos(); }, u_);
}
void IoStatementState::GotChar(int n) {
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index e00d54980aae59..7fecf4d9e41754 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -18,7 +18,6 @@
#include "io-error.h"
#include "flang/Common/optional.h"
#include "flang/Common/reference-wrapper.h"
-#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
#include <functional>
@@ -113,7 +112,7 @@ class IoStatementState {
// N.B.: this also works with base classes
template <typename A> A *get_if() const {
- return common::visit(
+ return visit(
[](auto &x) -> A * {
if constexpr (std::is_convertible_v<decltype(x.get()), A &>) {
return &x.get();
@@ -211,6 +210,61 @@ class IoStatementState {
}
private:
+  // Custom visitor for the variants of IoStatementState (binary search on
+  // u.index(), like Fortran::common::log2visit). In host builds it visits
+  // every variant. During device compilation it only visits the internal
+  // *output* variants InternalList/InternalFormattedIoStatementState; any
+  // other variant crashes at runtime, so its class is not used on device.
+  // TODO: parameterize Fortran::common::log2visit instead of copying it.
+  template <class T, class... Ts>
+  struct is_any_type : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
+
+  template <std::size_t LOW, std::size_t HIGH, typename RESULT,
+      typename VISITOR, typename VARIANT>
+  static inline RT_API_ATTRS RESULT Log2VisitHelper(
+      VISITOR &&visitor, std::size_t which, VARIANT &&u) {
+#if !defined(RT_DEVICE_COMPILATION)
+    constexpr bool isDevice{false};
+#else
+    constexpr bool isDevice{true};
+#endif
+    if constexpr (LOW == HIGH) { // Single candidate alternative left.
+      if constexpr (!isDevice ||
+          is_any_type<
+              std::variant_alternative_t<LOW, std::decay_t<decltype(u)>>,
+              Fortran::common::reference_wrapper<
+                  InternalListIoStatementState<Direction::Output>>,
+              Fortran::common::reference_wrapper<
+                  InternalFormattedIoStatementState<Direction::Output>>>::
+              value) {
+        return visitor(std::get<LOW>(std::forward<VARIANT>(u)));
+      } else { // Device build and a variant not supported on the device.
+        Terminator{__FILE__, __LINE__}.Crash(
+            "not implemented yet: IoStatementState variant %d\n",
+            static_cast<int>(LOW));
+      }
+    } else { // Recurse into [LOW, mid] or [mid + 1, HIGH].
+      static constexpr std::size_t mid{(HIGH + LOW) / 2};
+      if (which <= mid) {
+        return Log2VisitHelper<LOW, mid, RESULT>(
+            std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+      } else {
+        return Log2VisitHelper<(mid + 1), HIGH, RESULT>(
+            std::forward<VISITOR>(visitor), which, std::forward<VARIANT>(u));
+      }
+    }
+  }
+
+  template <typename VISITOR, typename VARIANT>
+  static inline RT_API_ATTRS auto visit(VISITOR &&visitor, VARIANT &&u)
+      -> decltype(visitor(std::get<0>(std::forward<VARIANT>(u)))) {
+    using Result = decltype(visitor(std::get<0>(std::forward<VARIANT>(u))));
+    static constexpr std::size_t high{
+        std::variant_size_v<std::decay_t<decltype(u)>> - 1};
+    return Log2VisitHelper<0, high, Result>(
+        std::forward<VISITOR>(visitor), u.index(), std::forward<VARIANT>(u));
+  }
+
std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
Fortran::common::reference_wrapper<CloseStatementState>,
Fortran::common::reference_wrapper<NoopStatementState>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/85179
More information about the llvm-branch-commits
mailing list