[flang-commits] [flang] [flang][runtime] Accept 128-bit integer SHIFT values in CSHIFT/EOSHIFT (PR #75246)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Dec 26 12:57:24 PST 2023


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/75246

From bc1d9c00a94ac95a3400faa976d4b714cc4350fb Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 12 Dec 2023 13:25:25 -0800
Subject: [PATCH] [flang][runtime] Accept 128-bit integer SHIFT values in
 CSHIFT/EOSHIFT

It would surprise me if this case ever arose outside a couple of
tests in llvm-test-suite/Fortran/gfortran/regression
(namely cshift_large_1.f90 and eoshift_large_1.f90), but now
at least those tests will pass.
---
 flang/runtime/tools.h              | 25 ++++++++++++++++++++
 flang/runtime/transformational.cpp | 37 +++++++++++++++++++-----------
 2 files changed, 48 insertions(+), 14 deletions(-)

diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index 9811bce25acd3c..ff05e76c8bb7b2 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -94,6 +94,31 @@ static inline RT_API_ATTRS std::int64_t GetInt64(
   }
 }
 
+static inline RT_API_ATTRS std::optional<std::int64_t> GetInt64Safe(
+    const char *p, std::size_t bytes, Terminator &terminator) {
+  switch (bytes) {
+  case 1:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 1> *>(p);
+  case 2:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 2> *>(p);
+  case 4:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 4> *>(p);
+  case 8:
+    return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 8> *>(p);
+  case 16: {
+    using Int128 = CppTypeFor<TypeCategory::Integer, 16>;
+    auto n{*reinterpret_cast<const Int128 *>(p)};
+    std::int64_t result = n;
+    if (result == n) {
+      return result;
+    }
+    return std::nullopt;
+  }
+  default:
+    terminator.Crash("GetInt64Safe: no case for %zd bytes", bytes);
+  }
+}
+
 template <typename INT>
 inline RT_API_ATTRS bool SetInteger(INT &x, int kind, std::int64_t value) {
   switch (kind) {
diff --git a/flang/runtime/transformational.cpp b/flang/runtime/transformational.cpp
index da8ec05c884fa3..cf1e61c0844d82 100644
--- a/flang/runtime/transformational.cpp
+++ b/flang/runtime/transformational.cpp
@@ -52,9 +52,11 @@ class ShiftControl {
           }
         }
       }
+    } else if (auto count{GetInt64Safe(
+                   shift_.OffsetElement<char>(), shiftElemLen_, terminator_)}) {
+      shiftCount_ = *count;
     } else {
-      shiftCount_ =
-          GetInt64(shift_.OffsetElement<char>(), shiftElemLen_, terminator_);
+      terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which);
     }
   }
   RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
@@ -67,8 +69,10 @@ class ShiftControl {
           ++k;
         }
       }
-      return GetInt64(
-          shift_.Element<char>(shiftAt), shiftElemLen_, terminator_);
+      auto count{GetInt64Safe(
+          shift_.Element<char>(shiftAt), shiftElemLen_, terminator_)};
+      RUNTIME_CHECK(terminator_, count.has_value());
+      return *count;
     } else {
       return shiftCount_; // invariant count extracted in Init()
     }
@@ -719,12 +723,15 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
   std::size_t resultElements{1};
   SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
   for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
-    resultExtent[j] = GetInt64(
-        shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
-    if (resultExtent[j] < 0) {
+    auto extent{GetInt64Safe(
+        shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator)};
+    if (!extent) {
+      terminator.Crash("RESHAPE: value of SHAPE(%d) exceeds 64 bits", j + 1);
+    } else if (*extent < 0) {
       terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
-          static_cast<std::intmax_t>(resultExtent[j]));
+          static_cast<std::intmax_t>(*extent));
     }
+    resultExtent[j] = *extent;
     resultElements *= resultExtent[j];
   }
 
@@ -762,14 +769,16 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
     SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
     std::size_t orderElementBytes{order->ElementBytes()};
     for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
-      auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
-          terminator)};
-      if (k < 1 || k > resultRank || ((values >> k) & 1)) {
+      auto k{GetInt64Safe(order->Element<char>(&orderSubscript),
+          orderElementBytes, terminator)};
+      if (!k) {
+        terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits");
+      } else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) {
         terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
-            static_cast<std::intmax_t>(k));
+            static_cast<std::intmax_t>(*k));
       }
-      values |= std::uint64_t{1} << k;
-      dimOrder[j] = k - 1;
+      values |= std::uint64_t{1} << *k;
+      dimOrder[j] = *k - 1;
     }
   } else {
     for (int j{0}; j < resultRank; ++j) {



More information about the flang-commits mailing list