[flang-commits] [flang] [flang][runtime] Accept 128-bit integer SHIFT values in CSHIFT/EOSHIFT (PR #75246)

via flang-commits flang-commits at lists.llvm.org
Tue Dec 12 13:28:40 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-runtime

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

It would surprise me if a 128-bit integer SHIFT= argument to CSHIFT or EOSHIFT ever arose outside a couple of tests in llvm-test-suite/Fortran/gfortran/regression (namely cshift_large_1.f90 and eoshift_large_1.f90), but now at least those tests will pass.

---
Full diff: https://github.com/llvm/llvm-project/pull/75246.diff


2 Files Affected:

- (modified) flang/runtime/tools.h (+25) 
- (modified) flang/runtime/transformational.cpp (+20-13) 


``````````diff
diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index ea659190e14391..c1da6898e881d3 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -94,6 +94,31 @@ static inline RT_API_ATTRS std::int64_t GetInt64(
   }
 }
 
+static inline RT_API_ATTRS std::optional<std::int64_t> GetInt64Safe(
+    const char *p, std::size_t bytes, Terminator &terminator) {
+  switch (bytes) {
+  case 1:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 1> *>(p);
+  case 2:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 2> *>(p);
+  case 4:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 4> *>(p);
+  case 8:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 8> *>(p);
+  case 16: {
+    using Int128 = CppTypeFor<TypeCategory::Integer, 16>;
+    auto n{*reinterpret_cast<const Int128 *>(p)};
+    std::int64_t result = n;
+    if (result == n) {
+      return result;
+    }
+    return std::nullopt;
+  }
+  default:
+    terminator.Crash("GetInt64Safe: no case for %zd bytes", bytes);
+  }
+}
+
 template <typename INT>
 inline RT_API_ATTRS bool SetInteger(INT &x, int kind, std::int64_t value) {
   switch (kind) {
diff --git a/flang/runtime/transformational.cpp b/flang/runtime/transformational.cpp
index da8ec05c884fa3..f0e38f88d7de5b 100644
--- a/flang/runtime/transformational.cpp
+++ b/flang/runtime/transformational.cpp
@@ -52,9 +52,11 @@ class ShiftControl {
           }
         }
       }
+    } else if (auto count{GetInt64Safe(
+                   shift_.OffsetElement<char>(), shiftElemLen_, terminator_)}) {
+      shiftCount_ = *count;
     } else {
-      shiftCount_ =
-          GetInt64(shift_.OffsetElement<char>(), shiftElemLen_, terminator_);
+      terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which);
     }
   }
   RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
@@ -67,8 +69,10 @@ class ShiftControl {
           ++k;
         }
       }
-      return GetInt64(
-          shift_.Element<char>(shiftAt), shiftElemLen_, terminator_);
+      auto count{GetInt64Safe(
+          shift_.Element<char>(shiftAt), shiftElemLen_, terminator_)};
+      RUNTIME_CHECK(terminator_, count.has_value());
+      return *count;
     } else {
       return shiftCount_; // invariant count extracted in Init()
     }
@@ -719,12 +723,13 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
   std::size_t resultElements{1};
   SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
   for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
-    resultExtent[j] = GetInt64(
-        shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
-    if (resultExtent[j] < 0) {
+    auto extent{GetInt64Safe(
+        shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator)};
+    if (!extent || *extent < 0) {
       terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
           static_cast<std::intmax_t>(resultExtent[j]));
     }
+    resultExtent[j] = *extent;
     resultElements *= resultExtent[j];
   }
 
@@ -762,14 +767,16 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
     SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
     std::size_t orderElementBytes{order->ElementBytes()};
     for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
-      auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
-          terminator)};
-      if (k < 1 || k > resultRank || ((values >> k) & 1)) {
+      auto k{GetInt64Safe(order->Element<char>(&orderSubscript),
+          orderElementBytes, terminator)};
+      if (!k) {
+        terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits");
+      } else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) {
         terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
-            static_cast<std::intmax_t>(k));
+            static_cast<std::intmax_t>(*k));
       }
-      values |= std::uint64_t{1} << k;
-      dimOrder[j] = k - 1;
+      values |= std::uint64_t{1} << *k;
+      dimOrder[j] = *k - 1;
     }
   } else {
     for (int j{0}; j < resultRank; ++j) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/75246


More information about the flang-commits mailing list