[flang-commits] [flang] cc1d13f - [flang] Fold MAXLOC and MINLOC

peter klausler via flang-commits flang-commits at lists.llvm.org
Tue Oct 5 11:22:10 PDT 2021


Author: peter klausler
Date: 2021-10-05T11:22:02-07:00
New Revision: cc1d13f997f6db6f2e6c209b9449695b91a68e32

URL: https://github.com/llvm/llvm-project/commit/cc1d13f997f6db6f2e6c209b9449695b91a68e32
DIFF: https://github.com/llvm/llvm-project/commit/cc1d13f997f6db6f2e6c209b9449695b91a68e32.diff

LOG: [flang] Fold MAXLOC and MINLOC

Generalize the code that folds FINDLOC (now the LocationHelper
class template) so that it also folds MAXLOC and MINLOC.

Differential Revision: https://reviews.llvm.org/D110951

Added:
    (none)

Modified: 
    flang/include/flang/Evaluate/type.h
    flang/lib/Evaluate/fold-character.cpp
    flang/lib/Evaluate/fold-integer.cpp
    flang/lib/Evaluate/fold-real.cpp
    flang/lib/Evaluate/fold-reduction.h
    flang/test/Evaluate/folding30.f90

Removed:
    (none)


################################################################################
diff  --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h
index a57d8107b7e33..23fc6cab809d8 100644
--- a/flang/include/flang/Evaluate/type.h
+++ b/flang/include/flang/Evaluate/type.h
@@ -335,8 +335,10 @@ using LogicalTypes = CategoryTypes<TypeCategory::Logical>;
 
 using FloatingTypes = common::CombineTuples<RealTypes, ComplexTypes>;
 using NumericTypes = common::CombineTuples<IntegerTypes, FloatingTypes>;
-using RelationalTypes = common::CombineTuples<NumericTypes, CharacterTypes>;
-using AllIntrinsicTypes = common::CombineTuples<RelationalTypes, LogicalTypes>;
+using RelationalTypes =
+    common::CombineTuples<IntegerTypes, RealTypes, CharacterTypes>;
+using AllIntrinsicTypes =
+    common::CombineTuples<NumericTypes, CharacterTypes, LogicalTypes>;
 using LengthlessIntrinsicTypes =
     common::CombineTuples<NumericTypes, LogicalTypes>;
 

diff  --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 7c37e4cb011fb..72e124a6a84b2 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -102,7 +102,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
           CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
     }
   }
-  // TODO: maxloc, minloc, transfer
+  // TODO: transfer
   return Expr<T>{std::move(funcRef)};
 }
 

diff  --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 441a5a9924f98..86fb46ba6094e 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -195,42 +195,58 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
   return Expr<T>{std::move(ref)};
 }
 
-// FINDLOC()
-class FindlocHelper {
+// FINDLOC(), MAXLOC(), & MINLOC()
+enum class WhichLocation { Findloc, Maxloc, Minloc };
+template <WhichLocation WHICH> class LocationHelper {
 public:
-  FindlocHelper(
+  LocationHelper(
       DynamicType &&type, ActualArguments &arg, FoldingContext &context)
       : type_{type}, arg_{arg}, context_{context} {}
   using Result = std::optional<Constant<SubscriptInteger>>;
-  using Types = AllIntrinsicTypes;
+  using Types = std::conditional_t<WHICH == WhichLocation::Findloc,
+      AllIntrinsicTypes, RelationalTypes>;
 
   template <typename T> Result Test() const {
     if (T::category != type_.category() || T::kind != type_.kind()) {
       return std::nullopt;
     }
-    CHECK(arg_.size() == 6);
+    CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5));
     Folder<T> folder{context_};
     Constant<T> *array{folder.Folding(arg_[0])};
-    Constant<T> *value{folder.Folding(arg_[1])};
-    if (!array || !value) {
+    if (!array) {
       return std::nullopt;
     }
+    std::optional<Constant<T>> value;
+    if constexpr (WHICH == WhichLocation::Findloc) {
+      if (const Constant<T> *p{folder.Folding(arg_[1])}) {
+        value.emplace(*p);
+      } else {
+        return std::nullopt;
+      }
+    }
     std::optional<int> dim;
     Constant<LogicalResult> *mask{
-        GetReductionMASK(arg_[3], array->shape(), context_)};
-    if ((!mask && arg_[3]) ||
-        !CheckReductionDIM(dim, context_, arg_, 2, array->Rank())) {
+        GetReductionMASK(arg_[maskArg], array->shape(), context_)};
+    if ((!mask && arg_[maskArg]) ||
+        !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) {
       return std::nullopt;
     }
     bool back{false};
-    if (arg_[5]) {
-      const auto *backConst{Folder<LogicalResult>{context_}.Folding(arg_[5])};
+    if (arg_[backArg]) {
+      const auto *backConst{
+          Folder<LogicalResult>{context_}.Folding(arg_[backArg])};
       if (backConst) {
         back = backConst->GetScalarValue().value().IsTrue();
       } else {
         return std::nullopt;
       }
     }
+    const RelationalOperator relation{WHICH == WhichLocation::Findloc
+            ? RelationalOperator::EQ
+            : WHICH == WhichLocation::Maxloc
+            ? (back ? RelationalOperator::GE : RelationalOperator::GT)
+            : back ? RelationalOperator::LE
+                   : RelationalOperator::LT};
     // Use lower bounds of 1 exclusively.
     array->SetLowerBoundsToOne();
     ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
@@ -252,10 +268,11 @@ class FindlocHelper {
       ConstantSubscript n{GetSize(resultShape)};
       for (ConstantSubscript j{0}; j < n; ++j) {
         ConstantSubscript hit{array->lbounds()[zbDim] - 1};
+        value.reset();
         for (ConstantSubscript k{0}; k < dimLength;
              ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
           if ((!mask || mask->At(maskAt).IsTrue()) &&
-              IsHit(array->At(at), *value)) {
+              IsHit(array->At(at), value, relation)) {
             hit = at[zbDim];
             if (!back) {
               break;
@@ -279,7 +296,7 @@ class FindlocHelper {
       for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
            mask && mask->IncrementSubscripts(maskAt)) {
         if ((!mask || mask->At(maskAt).IsTrue()) &&
-            IsHit(array->At(at), *value)) {
+            IsHit(array->At(at), value, relation)) {
           resultIndices = at;
           if (!back) {
             break;
@@ -297,43 +314,57 @@ class FindlocHelper {
 
 private:
   template <typename T>
-  bool IsHit(typename Constant<T>::Element element, Constant<T> value) const {
+  bool IsHit(typename Constant<T>::Element element,
+      std::optional<Constant<T>> &value,
+      [[maybe_unused]] RelationalOperator relation) const {
     std::optional<Expr<LogicalResult>> cmp;
-    if constexpr (T::category == TypeCategory::Logical) {
-      // array(at) .EQV. value?
-      cmp.emplace(
-          ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
-              LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
-              Expr<T>{std::move(value)}}}));
-    } else { // array(at) .EQ. value?
-      cmp.emplace(PackageRelation(RelationalOperator::EQ,
-          Expr<T>{Constant<T>{std::move(element)}}, Expr<T>{std::move(value)}));
-    }
-    Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
-    return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
+    if (value) {
+      if constexpr (T::category == TypeCategory::Logical) {
+        // array(at) .EQV. value?
+        static_assert(WHICH == WhichLocation::Findloc);
+        cmp.emplace(
+            ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
+                LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
+                Expr<T>{Constant<T>{*value}}}}));
+      } else { // compare array(at) to value
+        cmp.emplace(
+            PackageRelation(relation, Expr<T>{Constant<T>{std::move(element)}},
+                Expr<T>{Constant<T>{*value}}));
+      }
+      Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
+      return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
+    } else { // first unmasked element seen for MAXLOC/MINLOC
+      value.emplace(std::move(element));
+      return true;
+    }
   }
 
+  static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1};
+  static constexpr int maskArg{dimArg + 1};
+  static constexpr int backArg{maskArg + 2};
+
   DynamicType type_;
   ActualArguments &arg_;
   FoldingContext &context_;
 };
 
-static std::optional<Constant<SubscriptInteger>> FoldFindlocCall(
+template <WhichLocation which>
+static std::optional<Constant<SubscriptInteger>> FoldLocationCall(
     ActualArguments &arg, FoldingContext &context) {
-  CHECK(arg.size() == 6);
   if (arg[0]) {
     if (auto type{arg[0]->GetType()}) {
-      return common::SearchTypes(FindlocHelper{std::move(*type), arg, context});
+      return common::SearchTypes(
+          LocationHelper<which>{std::move(*type), arg, context});
     }
   }
   return std::nullopt;
 }
 
-template <typename T>
-static Expr<T> FoldFindloc(FoldingContext &context, FunctionRef<T> &&ref) {
+template <WhichLocation which, typename T>
+static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) {
   static_assert(T::category == TypeCategory::Integer);
   if (std::optional<Constant<SubscriptInteger>> found{
-          FoldFindlocCall(ref.arguments(), context)}) {
+          FoldLocationCall<which>(ref.arguments(), context)}) {
     return Expr<T>{Fold(
         context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
   } else {
@@ -451,7 +482,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
       DIE("exponent argument must be real");
     }
   } else if (name == "findloc") {
-    return FoldFindloc<T>(context, std::move(funcRef));
+    return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef));
   } else if (name == "huge") {
     return Expr<T>{Scalar<T>::HUGE()};
   } else if (name == "iachar" || name == "ichar") {
@@ -661,6 +692,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
           },
           sx->u);
     }
+  } else if (name == "maxloc") {
+    return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
   } else if (name == "maxval") {
     return FoldMaxvalMinval<T>(context, std::move(funcRef),
         RelationalOperator::GT, T::Scalar::Least());
@@ -669,6 +702,10 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "merge_bits") {
     return FoldElementalIntrinsic<T, T, T, T>(
         context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
+  } else if (name == "min") {
+    return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
+  } else if (name == "min0" || name == "min1") {
+    return RewriteSpecificMINorMAX(context, std::move(funcRef));
   } else if (name == "minexponent") {
     if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
       return std::visit(
@@ -678,10 +715,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
           },
           sx->u);
     }
-  } else if (name == "min") {
-    return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
-  } else if (name == "min0" || name == "min1") {
-    return RewriteSpecificMINorMAX(context, std::move(funcRef));
+  } else if (name == "minloc") {
+    return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
   } else if (name == "minval") {
     return FoldMaxvalMinval<T>(
         context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
@@ -853,8 +888,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "ubound") {
     return UBOUND(context, std::move(funcRef));
   }
-  // TODO: dot_product, ibits, image_status, ishftc,
-  // matmul, maxloc, minloc, sign, transfer
+  // TODO: dot_product, ibits, ishftc, matmul, sign, transfer
   return Expr<T>{std::move(funcRef)};
 }
 

diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index bffd8ea757555..93b3ce20a75b5 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -143,7 +143,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
     return Expr<T>{Scalar<T>::TINY()};
   }
   // TODO: dim, dot_product, fraction, matmul,
-  // maxloc, minloc, modulo, nearest, norm2, rrspacing, scale,
+  // modulo, nearest, norm2, rrspacing, scale,
   // set_exponent, spacing, transfer,
   // bessel_jn (transformational) and bessel_yn (transformational)
   return Expr<T>{std::move(funcRef)};

diff  --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index bfa4a1f80c79b..c7065c2c6bf62 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-// TODO: DOT_PRODUCT, NORM2, MAXLOC, MINLOC, PARITY
+// TODO: DOT_PRODUCT, NORM2, PARITY
 
 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_

diff  --git a/flang/test/Evaluate/folding30.f90 b/flang/test/Evaluate/folding30.f90
index 748723c08e88d..d0bbfde5ee480 100644
--- a/flang/test/Evaluate/folding30.f90
+++ b/flang/test/Evaluate/folding30.f90
@@ -1,21 +1,57 @@
 ! RUN: %python %S/test_folding.py %s %flang_fc1
-! Tests folding of FINDLOC
+! Tests folding of FINDLOC, MAXLOC, & MINLOC
 module m1
   integer, parameter :: ia1(2:6) = [1, 2, 3, 2, 1]
   integer, parameter :: ia2(2:3,2:4) = reshape([1, 2, 3, 3, 2, 1], shape(ia2))
 
-  logical, parameter :: ti1a = all(findloc(ia1, 1) == 1)
-  logical, parameter :: ti1ar = rank(findloc(ia1, 1)) == 1
-  logical, parameter :: ti1ak = kind(findloc(ia1, 1, kind=2)) == 2
-  logical, parameter :: ti1ad = findloc(ia1, 1, dim=1) == 1
-  logical, parameter :: ti1adr = rank(findloc(ia1, 1, dim=1)) == 0
-  logical, parameter :: ti1b = all(findloc(ia1, 1, back=.true.) == 5)
-  logical, parameter :: ti1c = all(findloc(ia1, 2, mask=[.false., .false., .true., .true., .true.]) == 4)
+  logical, parameter :: fi1a = all(findloc(ia1, 1) == 1)
+  logical, parameter :: fi1ar = rank(findloc(ia1, 1)) == 1
+  logical, parameter :: fi1ak = kind(findloc(ia1, 1, kind=2)) == 2
+  logical, parameter :: fi1ad = findloc(ia1, 1, dim=1) == 1
+  logical, parameter :: fi1adr = rank(findloc(ia1, 1, dim=1)) == 0
+  logical, parameter :: fi1b = all(findloc(ia1, 1, back=.true.) == 5)
+  logical, parameter :: fi1c = all(findloc(ia1, 2, mask=[.false., .false., .true., .true., .true.]) == 4)
 
-  logical, parameter :: ti2a = all(findloc(ia2, 1) == [1, 1])
-  logical, parameter :: ti2ar = rank(findloc(ia2, 1)) == 1
-  logical, parameter :: ti2b = all(findloc(ia2, 1, back=.true.) == [2, 3])
-  logical, parameter :: ti2c = all(findloc(ia2, 2, mask=reshape([.false., .false., .true., .true., .true., .false.], shape(ia2))) == [1, 3])
-  logical, parameter :: ti2d = all(findloc(ia2, 1, dim=1) == [1, 0, 2])
-  logical, parameter :: ti2e = all(findloc(ia2, 1, dim=2) == [1, 3])
+  logical, parameter :: fi2a = all(findloc(ia2, 1) == [1, 1])
+  logical, parameter :: fi2ar = rank(findloc(ia2, 1)) == 1
+  logical, parameter :: fi2b = all(findloc(ia2, 1, back=.true.) == [2, 3])
+  logical, parameter :: fi2c = all(findloc(ia2, 2, mask=reshape([.false., .false., .true., .true., .true., .false.], shape(ia2))) == [1, 3])
+  logical, parameter :: fi2d = all(findloc(ia2, 1, dim=1) == [1, 0, 2])
+  logical, parameter :: fi2e = all(findloc(ia2, 1, dim=2) == [1, 3])
+
+  logical, parameter :: xi1a = all(maxloc(ia1) == 3)
+  logical, parameter :: xi1ar = rank(maxloc(ia1)) == 1
+  logical, parameter :: xi1ak = kind(maxloc(ia1, kind=2)) == 2
+  logical, parameter :: xi1ad = maxloc(ia1, dim=1) == 3
+  logical, parameter :: xi1adr = rank(maxloc(ia1, dim=1)) == 0
+  logical, parameter :: xi1b = all(maxloc(ia1, back=.true.) == 3)
+  logical, parameter :: xi1c = all(maxloc(ia1, mask=[.false., .true., .false., .true., .true.]) == 2)
+  logical, parameter :: xi1d = all(maxloc(ia1, mask=[.false., .true., .false., .true., .true.], back=.true.) == 4)
+
+  logical, parameter :: xi2a = all(maxloc(ia2) == [1, 2])
+  logical, parameter :: xi2ar = rank(maxloc(ia2)) == 1
+  logical, parameter :: xi2b = all(maxloc(ia2, back=.true.) == [2, 2])
+  logical, parameter :: xi2c = all(maxloc(ia2, mask=reshape([.false., .true., .true., .false., .true., .true.], shape(ia2))) == [2, 1])
+  logical, parameter :: xi2d = all(maxloc(ia2, mask=reshape([.false., .true., .true., .false., .true., .true.], shape(ia2)), back=.true.) == [1, 3])
+  logical, parameter :: xi2e = all(maxloc(ia2, dim=1) == [2, 1, 1])
+  logical, parameter :: xi2f = all(maxloc(ia2, dim=1, back=.true.) == [2, 2, 1])
+  logical, parameter :: xi2g = all(maxloc(ia2, dim=2) == [2, 2])
+
+  logical, parameter :: ni1a = all(minloc(ia1) == 1)
+  logical, parameter :: ni1ar = rank(minloc(ia1)) == 1
+  logical, parameter :: ni1ak = kind(minloc(ia1, kind=2)) == 2
+  logical, parameter :: ni1ad = minloc(ia1, dim=1) == 1
+  logical, parameter :: ni1adr = rank(minloc(ia1, dim=1)) == 0
+  logical, parameter :: ni1b = all(minloc(ia1, back=.true.) == 5)
+  logical, parameter :: ni1c = all(minloc(ia1, mask=[.false., .true., .true., .true., .false.]) == 2)
+  logical, parameter :: ni1d = all(minloc(ia1, mask=[.false., .true., .true., .true., .false.], back=.true.) == 4)
+
+  logical, parameter :: ni2a = all(minloc(ia2) == [1, 1])
+  logical, parameter :: ni2ar = rank(minloc(ia2)) == 1
+  logical, parameter :: ni2b = all(minloc(ia2, back=.true.) == [2, 3])
+  logical, parameter :: ni2c = all(minloc(ia2, mask=reshape([.false., .true., .true., .false., .true., .false.], shape(ia2))) == [2, 1])
+  logical, parameter :: ni2d = all(minloc(ia2, mask=reshape([.false., .true., .true., .false., .true., .false.], shape(ia2)), back=.true.) == [1, 3])
+  logical, parameter :: ni2e = all(minloc(ia2, dim=1) == [1, 1, 2])
+  logical, parameter :: ni2f = all(minloc(ia2, dim=1, back=.true.) == [1, 2, 2])
+  logical, parameter :: ni2g = all(minloc(ia2, dim=2) == [1, 3])
 end module


        


More information about the flang-commits mailing list