[flang-commits] [flang] 4244cab - [flang] Check constant arguments to bit manipulation intrinsics even if not foldable

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Sat Dec 17 15:25:07 PST 2022


Author: Peter Klausler
Date: 2022-12-17T15:24:59-08:00
New Revision: 4244cab23afcee5c3515d67d6d340cf82ce5289f

URL: https://github.com/llvm/llvm-project/commit/4244cab23afcee5c3515d67d6d340cf82ce5289f
DIFF: https://github.com/llvm/llvm-project/commit/4244cab23afcee5c3515d67d6d340cf82ce5289f.diff

LOG: [flang] Check constant arguments to bit manipulation intrinsics even if not foldable

When some arguments that specify bit positions, shift counts, and field sizes are
constant at compilation time, but other arguments are not constant, the compiler
should still validate the constant ones.  In the current sources, validation is
only performed for intrinsic references that can be folded to constants.

Differential Revision: https://reviews.llvm.org/D140152

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-integer.cpp
    flang/lib/Evaluate/intrinsics.cpp
    flang/test/Evaluate/fold-ishftc.f90
    flang/test/Semantics/ishftc.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 0aaf5b182635e..61e76b59b1f9b 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -577,6 +577,21 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
         name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
     // Third argument can be of any kind. However, it must be less than or equal
     // to BIT_SIZE. It can be converted to Int4 to simplify.
+    if (const auto *shiftCon{Folder<Int4>(context).Folding(args[2])}) {
+      for (const auto &scalar : shiftCon->values()) {
+        std::int64_t shiftVal{scalar.ToInt64()};
+        if (shiftVal < 0) {
+          context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US,
+              std::intmax_t{shiftVal}, name);
+          break;
+        } else if (shiftVal > T::Scalar::bits) {
+          context.messages().Say(
+              "SHIFT=%jd count for %s is greater than %d"_err_en_US,
+              std::intmax_t{shiftVal}, name, T::Scalar::bits);
+          break;
+        }
+      }
+    }
     return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
         ScalarFunc<T, T, T, Int4>(
             [&fptr](const Scalar<T> &i, const Scalar<T> &j,
@@ -662,42 +677,69 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     } else {
       common::die("missing case to fold intrinsic function %s", name.c_str());
     }
+    if (const auto *posCon{Folder<Int4>(context).Folding(args[1])}) {
+      for (const auto &scalar : posCon->values()) {
+        std::int64_t posVal{scalar.ToInt64()};
+        if (posVal < 0) {
+          context.messages().Say(
+              "bit position for %s (%jd) is negative"_err_en_US, name,
+              std::intmax_t{posVal});
+          break;
+        } else if (posVal >= T::Scalar::bits) {
+          context.messages().Say(
+              "bit position for %s (%jd) is not less than %d"_err_en_US, name,
+              std::intmax_t{posVal}, T::Scalar::bits);
+          break;
+        }
+      }
+    }
     return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
-        ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
-                                   const Scalar<Int4> &pos) -> Scalar<T> {
-          auto posVal{static_cast<int>(pos.ToInt64())};
-          if (posVal < 0) {
-            context.messages().Say(
-                "bit position for %s (%d) is negative"_err_en_US, name, posVal);
-          } else if (posVal >= i.bits) {
-            context.messages().Say(
-                "bit position for %s (%d) is not less than %d"_err_en_US, name,
-                posVal, i.bits);
-          }
-          return std::invoke(fptr, i, posVal);
-        }));
+        ScalarFunc<T, T, Int4>(
+            [&](const Scalar<T> &i, const Scalar<Int4> &pos) -> Scalar<T> {
+              return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
+            }));
   } else if (name == "ibits") {
+    const auto *posCon{Folder<Int4>(context).Folding(args[1])};
+    const auto *lenCon{Folder<Int4>(context).Folding(args[2])};
+    if (posCon && lenCon &&
+        (posCon->size() == 1 || lenCon->size() == 1 ||
+            posCon->size() == lenCon->size())) {
+      auto posIter{posCon->values().begin()};
+      auto lenIter{lenCon->values().begin()};
+      for (; posIter != posCon->values().end() &&
+           lenIter != lenCon->values().end();
+           ++posIter, ++lenIter) {
+        posIter = posIter == posCon->values().end() ? posCon->values().begin()
+                                                    : posIter;
+        lenIter = lenIter == lenCon->values().end() ? lenCon->values().begin()
+                                                    : lenIter;
+        auto posVal{static_cast<int>(posIter->ToInt64())};
+        auto lenVal{static_cast<int>(lenIter->ToInt64())};
+        if (posVal < 0) {
+          context.messages().Say(
+              "bit position for IBITS(POS=%jd,LEN=%jd) is negative"_err_en_US,
+              std::intmax_t{posVal}, std::intmax_t{lenVal});
+          break;
+        } else if (lenVal < 0) {
+          context.messages().Say(
+              "bit length for IBITS(POS=%jd,LEN=%jd) is negative"_err_en_US,
+              std::intmax_t{posVal}, std::intmax_t{lenVal});
+          break;
+        } else if (posVal + lenVal > T::Scalar::bits) {
+          context.messages().Say(
+              "IBITS(POS=%jd,LEN=%jd) must have POS+LEN no greater than %d"_err_en_US,
+              std::intmax_t{posVal}, std::intmax_t{lenVal}, T::Scalar::bits);
+          break;
+        }
+      }
+    }
     return FoldElementalIntrinsic<T, T, Int4, Int4>(context, std::move(funcRef),
-        ScalarFunc<T, T, Int4, Int4>([&](const Scalar<T> &i,
-                                         const Scalar<Int4> &pos,
-                                         const Scalar<Int4> &len) -> Scalar<T> {
-          auto posVal{static_cast<int>(pos.ToInt64())};
-          auto lenVal{static_cast<int>(len.ToInt64())};
-          if (posVal < 0) {
-            context.messages().Say(
-                "bit position for IBITS(POS=%d,LEN=%d) is negative"_err_en_US,
-                posVal, lenVal);
-          } else if (lenVal < 0) {
-            context.messages().Say(
-                "bit length for IBITS(POS=%d,LEN=%d) is negative"_err_en_US,
-                posVal, lenVal);
-          } else if (posVal + lenVal > i.bits) {
-            context.messages().Say(
-                "IBITS(POS=%d,LEN=%d) must have POS+LEN no greater than %d"_err_en_US,
-                posVal + lenVal, i.bits);
-          }
-          return i.IBITS(posVal, lenVal);
-        }));
+        ScalarFunc<T, T, Int4, Int4>(
+            [&](const Scalar<T> &i, const Scalar<Int4> &pos,
+                const Scalar<Int4> &len) -> Scalar<T> {
+              return i.IBITS(static_cast<int>(pos.ToInt64()),
+                  static_cast<int>(len.ToInt64()));
+            }));
   } else if (name == "index" || name == "scan" || name == "verify") {
     if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
       return common::visit(
@@ -761,24 +803,79 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "iparity") {
     return FoldBitReduction(
         context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
-  } else if (name == "ishft") {
-    return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
-        ScalarFunc<T, T, Int4>(
-            [&](const Scalar<T> &i, const Scalar<Int4> &pos) -> Scalar<T> {
-              auto posVal{static_cast<int>(pos.ToInt64())};
-              if (posVal < -i.bits) {
-                context.messages().Say(
-                    "SHIFT=%d count for ishft is less than %d"_err_en_US,
-                    posVal, -i.bits);
-              } else if (posVal > i.bits) {
-                context.messages().Say(
-                    "SHIFT=%d count for ishft is greater than %d"_err_en_US,
-                    posVal, i.bits);
-              }
-              return i.ISHFT(posVal);
-            }));
-  } else if (name == "ishftc") {
-    if (args.at(2)) { // SIZE= is present
+  } else if (name == "ishft" || name == "ishftc") {
+    const auto *shiftCon{Folder<Int4>(context).Folding(args[1])};
+    if (shiftCon) {
+      for (const auto &scalar : shiftCon->values()) {
+        std::int64_t shiftVal{scalar.ToInt64()};
+        if (shiftVal < -T::Scalar::bits) {
+          context.messages().Say(
+              "SHIFT=%jd count for %s is less than %d"_err_en_US,
+              std::intmax_t{shiftVal}, name, -T::Scalar::bits);
+          break;
+        } else if (shiftVal > T::Scalar::bits) {
+          context.messages().Say(
+              "SHIFT=%jd count for %s is greater than %d"_err_en_US,
+              std::intmax_t{shiftVal}, name, T::Scalar::bits);
+          break;
+        }
+      }
+    }
+    if (args.size() == 3) { // ISHFTC
+      if (const auto *sizeCon{Folder<Int4>(context).Folding(args[2])}) {
+        for (const auto &scalar : sizeCon->values()) {
+          std::int64_t sizeVal{scalar.ToInt64()};
+          if (sizeVal <= 0) {
+            context.messages().Say(
+                "SIZE=%jd count for ishftc is not positive"_err_en_US,
+                std::intmax_t{sizeVal}, name);
+            break;
+          } else if (sizeVal > T::Scalar::bits) {
+            context.messages().Say(
+                "SIZE=%jd count for ishftc is greater than %d"_err_en_US,
+                std::intmax_t{sizeVal}, T::Scalar::bits);
+            break;
+          }
+        }
+        if (shiftCon &&
+            (shiftCon->size() == 1 || sizeCon->size() == 1 ||
+                shiftCon->size() == sizeCon->size())) {
+          auto shiftIter{shiftCon->values().begin()};
+          auto sizeIter{sizeCon->values().begin()};
+          for (; shiftIter != shiftCon->values().end() &&
+               sizeIter != sizeCon->values().end();
+               ++shiftIter, ++sizeIter) {
+            shiftIter = shiftIter == shiftCon->values().end()
+                ? shiftCon->values().begin()
+                : shiftIter;
+            sizeIter = sizeIter == sizeCon->values().end()
+                ? sizeCon->values().begin()
+                : sizeIter;
+            auto shiftVal{static_cast<int>(shiftIter->ToInt64())};
+            auto sizeVal{static_cast<int>(sizeIter->ToInt64())};
+            if (sizeVal > 0 && std::abs(shiftVal) > sizeVal) {
+              context.messages().Say(
+                  "SHIFT=%jd count for ishftc is greater in magnitude than SIZE=%jd"_err_en_US,
+                  std::intmax_t{shiftVal}, std::intmax_t{sizeVal});
+              break;
+            }
+          }
+        }
+      }
+    }
+    if (name == "ishft") {
+      return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
+          ScalarFunc<T, T, Int4>(
+              [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
+                return i.ISHFT(static_cast<int>(shift.ToInt64()));
+              }));
+    } else if (!args.at(2)) { // ISHFTC(no SIZE=)
+      return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
+          ScalarFunc<T, T, Int4>(
+              [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
+                return i.ISHFTC(static_cast<int>(shift.ToInt64()));
+              }));
+    } else { // ISHFTC(with SIZE=)
       return FoldElementalIntrinsic<T, T, Int4, Int4>(context,
           std::move(funcRef),
           ScalarFunc<T, T, Int4, Int4>(
@@ -789,13 +886,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
                 auto sizeVal{static_cast<int>(size.ToInt64())};
                 return i.ISHFTC(shiftVal, sizeVal);
               }));
-    } else { // no SIZE=
-      return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
-          ScalarFunc<T, T, Int4>(
-              [&](const Scalar<T> &i, const Scalar<Int4> &count) -> Scalar<T> {
-                auto countVal{static_cast<int>(count.ToInt64())};
-                return i.ISHFTC(countVal);
-              }));
     }
   } else if (name == "izext" || name == "jzext") {
     if (args.size() == 1) {
@@ -1045,20 +1135,26 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     } else {
       common::die("missing case to fold intrinsic function %s", name.c_str());
     }
+    if (const auto *shiftCon{Folder<Int4>(context).Folding(args[1])}) {
+      for (const auto &scalar : shiftCon->values()) {
+        std::int64_t shiftVal{scalar.ToInt64()};
+        if (shiftVal < 0) {
+          context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US,
+              std::intmax_t{shiftVal}, name, -T::Scalar::bits);
+          break;
+        } else if (shiftVal > T::Scalar::bits) {
+          context.messages().Say(
+              "SHIFT=%jd count for %s is greater than %d"_err_en_US,
+              std::intmax_t{shiftVal}, name, T::Scalar::bits);
+          break;
+        }
+      }
+    }
     return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
-        ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
-                                   const Scalar<Int4> &pos) -> Scalar<T> {
-          auto posVal{static_cast<int>(pos.ToInt64())};
-          if (posVal < 0) {
-            context.messages().Say(
-                "SHIFT=%d count for %s is negative"_err_en_US, posVal, name);
-          } else if (posVal > i.bits) {
-            context.messages().Say(
-                "SHIFT=%d count for %s is greater than %d"_err_en_US, posVal,
-                name, i.bits);
-          }
-          return std::invoke(fptr, i, posVal);
-        }));
+        ScalarFunc<T, T, Int4>(
+            [&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
+              return std::invoke(fptr, i, static_cast<int>(shift.ToInt64()));
+            }));
   } else if (name == "sign") {
     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
         ScalarFunc<T, T, T>([&context](const Scalar<T> &j,

diff  --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp
index 308e3e9d61855..6f4914d8fa715 100644
--- a/flang/lib/Evaluate/intrinsics.cpp
+++ b/flang/lib/Evaluate/intrinsics.cpp
@@ -2918,24 +2918,6 @@ static bool ApplySpecificChecks(SpecificCall &call, FoldingContext &context) {
     if (const auto &arg{call.arguments[0]}) {
       ok = CheckForNonPositiveValues(context, *arg, name, "image");
     }
-  } else if (name == "ishftc") {
-    if (const auto &sizeArg{call.arguments[2]}) {
-      ok = CheckForNonPositiveValues(context, *sizeArg, name, "size");
-      if (ok) {
-        if (auto sizeVal{ToInt64(sizeArg->UnwrapExpr())}) {
-          if (const auto &shiftArg{call.arguments[1]}) {
-            if (auto shiftVal{ToInt64(shiftArg->UnwrapExpr())}) {
-              if (std::abs(*shiftVal) > *sizeVal) {
-                ok = false;
-                context.messages().Say(shiftArg->sourceLocation(),
-                    "The absolute value of the 'shift=' argument for intrinsic '%s' must be less than or equal to the 'size=' argument"_err_en_US,
-                    name);
-              }
-            }
-          }
-        }
-      }
-    }
   } else if (name == "lcobound") {
     return CheckDimAgainstCorank(call, context);
   } else if (name == "loc") {

diff  --git a/flang/test/Evaluate/fold-ishftc.f90 b/flang/test/Evaluate/fold-ishftc.f90
index 35c6bf3283b78..705134823956c 100644
--- a/flang/test/Evaluate/fold-ishftc.f90
+++ b/flang/test/Evaluate/fold-ishftc.f90
@@ -1,9 +1,9 @@
 ! RUN: %python %S/test_folding.py %s %flang_fc1
 ! Tests folding of ISHFTC
 module m
-  integer, parameter :: shift8s(*) = ishftc(257, shift = [(ict, ict = -9, 9)], 8)
-  integer, parameter :: expect1(*) = 256 + [128, 1, 2, 4, 8, 16, 32, 64, 128, &
-                                            1, 2, 4, 8, 16, 32, 64, 128, 1, 2]
+  integer, parameter :: shift8s(*) = ishftc(257, shift = [(ict, ict = -8, 8)], 8)
+  integer, parameter :: expect1(*) = 256 + [1, 2, 4, 8, 16, 32, 64, 128, &
+                                            1, 2, 4, 8, 16, 32, 64, 128, 1]
   logical, parameter :: test_1 = all(shift8s == expect1)
   integer, parameter :: sizes(*) = [(ishftc(257, ict, [(isz, isz = 1, 8)]), ict = -1, 1)]
   integer, parameter :: expect2(*) = 256 + [[1, 2, 4, 8, 16, 32, 64, 128], &

diff  --git a/flang/test/Semantics/ishftc.f90 b/flang/test/Semantics/ishftc.f90
index 3e0ebe5a41d06..15d1213999cc9 100644
--- a/flang/test/Semantics/ishftc.f90
+++ b/flang/test/Semantics/ishftc.f90
@@ -18,31 +18,31 @@ program test_ishftc
   n = ishftc(3, 2, 3)
   array_result = ishftc([3,3], [2,2], [3,3])
 
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must be a positive value, but is -3
+  !ERROR: SIZE=-3 count for ishftc is not positive
   n = ishftc(3, 2, -3)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must be a positive value, but is 0
+  !ERROR: SIZE=0 count for ishftc is not positive
   n = ishftc(3, 2, 0)
-  !ERROR: The absolute value of the 'shift=' argument for intrinsic 'ishftc' must be less than or equal to the 'size=' argument
+  !ERROR: SHIFT=2 count for ishftc is greater in magnitude than SIZE=1
   n = ishftc(3, 2, 1)
-  !ERROR: The absolute value of the 'shift=' argument for intrinsic 'ishftc' must be less than or equal to the 'size=' argument
+  !ERROR: SHIFT=-2 count for ishftc is greater in magnitude than SIZE=1
   n = ishftc(3, -2, 1)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=-3 count for ishftc is not positive
   array_result = ishftc([3,3], [2,2], [-3,3])
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=-3 count for ishftc is not positive
   array_result = ishftc([3,3], [2,2], [-3,-3])
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=-3 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr1)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=0 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr2)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=0 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr3)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=0 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr4)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=0 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr5)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=0 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr6)
-  !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
+  !ERROR: SIZE=0 count for ishftc is not positive
   array_result = ishftc([3,3], [-2,-2], const_arr7)
 
 end program test_ishftc


        


More information about the flang-commits mailing list