[flang-commits] [flang] [flang][runtime] Accept 128-bit integer SHIFT values in CSHIFT/EOSHIFT (PR #75246)
via flang-commits
flang-commits at lists.llvm.org
Tue Dec 12 13:28:40 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
It would surprise me if this case ever arose outside a couple of tests in llvm-test-suite/Fortran/gfortran/regression (namely cshift_large_1.f90 and eoshift_large_1.f90), but now at least those tests will pass.
---
Full diff: https://github.com/llvm/llvm-project/pull/75246.diff
2 Files Affected:
- (modified) flang/runtime/tools.h (+25)
- (modified) flang/runtime/transformational.cpp (+20-13)
``````````diff
diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index ea659190e14391..c1da6898e881d3 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -94,6 +94,31 @@ static inline RT_API_ATTRS std::int64_t GetInt64(
}
}
+static inline RT_API_ATTRS std::optional<std::int64_t> GetInt64Safe(
+ const char *p, std::size_t bytes, Terminator &terminator) {
+ switch (bytes) {
+ case 1:
+ return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 1> *>(p);
+ case 2:
+ return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 2> *>(p);
+ case 4:
+ return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 4> *>(p);
+ case 8:
+ return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 8> *>(p);
+ case 16: {
+ using Int128 = CppTypeFor<TypeCategory::Integer, 16>;
+ auto n{*reinterpret_cast<const Int128 *>(p)};
+ std::int64_t result = n;
+ if (result == n) {
+ return result;
+ }
+ return std::nullopt;
+ }
+ default:
+ terminator.Crash("GetInt64Safe: no case for %zd bytes", bytes);
+ }
+}
+
template <typename INT>
inline RT_API_ATTRS bool SetInteger(INT &x, int kind, std::int64_t value) {
switch (kind) {
diff --git a/flang/runtime/transformational.cpp b/flang/runtime/transformational.cpp
index da8ec05c884fa3..f0e38f88d7de5b 100644
--- a/flang/runtime/transformational.cpp
+++ b/flang/runtime/transformational.cpp
@@ -52,9 +52,11 @@ class ShiftControl {
}
}
}
+ } else if (auto count{GetInt64Safe(
+ shift_.OffsetElement<char>(), shiftElemLen_, terminator_)}) {
+ shiftCount_ = *count;
} else {
- shiftCount_ =
- GetInt64(shift_.OffsetElement<char>(), shiftElemLen_, terminator_);
+ terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which);
}
}
RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
@@ -67,8 +69,10 @@ class ShiftControl {
++k;
}
}
- return GetInt64(
- shift_.Element<char>(shiftAt), shiftElemLen_, terminator_);
+ auto count{GetInt64Safe(
+ shift_.Element<char>(shiftAt), shiftElemLen_, terminator_)};
+ RUNTIME_CHECK(terminator_, count.has_value());
+ return *count;
} else {
return shiftCount_; // invariant count extracted in Init()
}
@@ -719,12 +723,13 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
std::size_t resultElements{1};
SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
- resultExtent[j] = GetInt64(
- shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
- if (resultExtent[j] < 0) {
+ auto extent{GetInt64Safe(
+ shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator)};
+ if (!extent || *extent < 0) {
terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
static_cast<std::intmax_t>(resultExtent[j]));
}
+ resultExtent[j] = *extent;
resultElements *= resultExtent[j];
}
@@ -762,14 +767,16 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
std::size_t orderElementBytes{order->ElementBytes()};
for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
- auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
- terminator)};
- if (k < 1 || k > resultRank || ((values >> k) & 1)) {
+ auto k{GetInt64Safe(order->Element<char>(&orderSubscript),
+ orderElementBytes, terminator)};
+ if (!k) {
+ terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits");
+ } else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) {
terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
- static_cast<std::intmax_t>(k));
+ static_cast<std::intmax_t>(*k));
}
- values |= std::uint64_t{1} << k;
- dimOrder[j] = k - 1;
+ values |= std::uint64_t{1} << *k;
+ dimOrder[j] = *k - 1;
}
} else {
for (int j{0}; j < resultRank; ++j) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/75246
More information about the flang-commits mailing list