[flang-commits] [flang] [flang][NFC] Move new code to right place (PR #144551)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Jun 17 08:52:57 PDT 2025


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/144551

Some new code was added to flang/Semantics that only depends on facilities in flang/Evaluate.  Move it into Evaluate and clean up some minor stylistic problems.

>From ef217ad4243945319a91bd669768f86f661177dc Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 17 Jun 2025 08:49:52 -0700
Subject: [PATCH] [flang][NFC] Move new code to right place

Some new code was added to flang/Semantics that only depends on
facilities in flang/Evaluate.  Move it into Evaluate and clean up
some minor stylistic problems.
---
 flang/include/flang/Evaluate/tools.h        | 148 +++++++++
 flang/include/flang/Semantics/tools.h       | 149 ---------
 flang/lib/Evaluate/tools.cpp                | 311 ++++++++++++++++++
 flang/lib/Lower/OpenACC.cpp                 |   2 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp           |   7 +-
 flang/lib/Semantics/check-omp-structure.cpp |   3 +
 flang/lib/Semantics/tools.cpp               | 329 --------------------
 7 files changed, 467 insertions(+), 482 deletions(-)

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 1959d5f3a5899..e04621f71f9a7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
 }
 
+// True if the first symbol of *lhs also occurs in *rhs; false on null input.
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
+
+namespace operation {
+
+enum class Operator {
+  Unknown,
+  Add,
+  And,
+  Associated,
+  Call,
+  Constant,
+  Convert,
+  Div,
+  Eq,
+  Eqv,
+  False,
+  Ge,
+  Gt,
+  Identity,
+  Intrinsic,
+  Le,
+  Lt,
+  Max,
+  Min,
+  Mul,
+  Ne,
+  Neqv,
+  Not,
+  Or,
+  Pow,
+  Resize, // Convert within the same TypeCategory
+  Sub,
+  True,
+};
+
+std::string ToString(Operator op); // e.g. "+", "AND", "function-call"
+
+template <typename... Ts, int Kind> // .AND., .OR., .EQV., .NEQV., .NOT.
+Operator OperationCode(
+    const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
+  switch (op.derived().logicalOperator) {
+  case common::LogicalOperator::And:
+    return Operator::And;
+  case common::LogicalOperator::Or:
+    return Operator::Or;
+  case common::LogicalOperator::Eqv:
+    return Operator::Eqv;
+  case common::LogicalOperator::Neqv:
+    return Operator::Neqv;
+  case common::LogicalOperator::Not:
+    return Operator::Not;
+  }
+  return Operator::Unknown;
+}
+
+template <typename T, typename... Ts> // .LT., .LE., .EQ., .NE., .GE., .GT.
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
+  switch (op.derived().opr) {
+  case common::RelationalOperator::LT:
+    return Operator::Lt;
+  case common::RelationalOperator::LE:
+    return Operator::Le;
+  case common::RelationalOperator::EQ:
+    return Operator::Eq;
+  case common::RelationalOperator::NE:
+    return Operator::Ne;
+  case common::RelationalOperator::GE:
+    return Operator::Ge;
+  case common::RelationalOperator::GT:
+    return Operator::Gt;
+  }
+  return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
+  return Operator::Add;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
+  return Operator::Sub;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
+  return Operator::Mul;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
+  return Operator::Div;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
+  return Operator::Pow;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
+  return Operator::Pow;
+}
+
+template <typename T, common::TypeCategory C, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
+  if constexpr (C == T::category) { // Same category: only the kind changes
+    return Operator::Resize;
+  } else {
+    return Operator::Convert;
+  }
+}
+
+template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
+  return Operator::Constant;
+}
+
+template <typename T> Operator OperationCode(const T &) { // Anything else
+  return Operator::Unknown;
+}
+
+Operator OperationCode(const evaluate::ProcedureDesignator &proc);
+
+} // namespace operation
+
+// Return information about the top-level operation, ignoring parentheses and
+// same-category (resizing) conversions: the operation code and its arguments.
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr);
+
+// True if expr is x, or x under parentheses/conversions (see GetConvertInput).
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x);
+
+// Strip away any top-level parentheses and conversions (Convert,
+// ComplexComponent, or ComplexConstructor(x, 0)) and return the operand they
+// wrap; a Designator, FunctionRef, Constant or other Operation is returned
+// as-is. If the input is not Operation, Designator, FunctionRef or Constant,
+// it returns std::nullopt.
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x);
+
 } // namespace Fortran::evaluate
 
 namespace Fortran::semantics {
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index 69375a83dec25..f3cfa9b99fb4d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring);
 // Check for ambiguous USE associations
 bool HadUseError(SemanticsContext &, SourceName at, const Symbol *);
 
-// Checks whether the symbol on the LHS is present in the RHS expression.
-bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs);
-
-namespace operation {
-
-enum class Operator {
-  Unknown,
-  Add,
-  And,
-  Associated,
-  Call,
-  Constant,
-  Convert,
-  Div,
-  Eq,
-  Eqv,
-  False,
-  Ge,
-  Gt,
-  Identity,
-  Intrinsic,
-  Le,
-  Lt,
-  Max,
-  Min,
-  Mul,
-  Ne,
-  Neqv,
-  Not,
-  Or,
-  Pow,
-  Resize, // Convert within the same TypeCategory
-  Sub,
-  True,
-};
-
-std::string ToString(Operator op);
-
-template <typename... Ts, int Kind>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
-  switch (op.derived().logicalOperator) {
-  case common::LogicalOperator::And:
-    return Operator::And;
-  case common::LogicalOperator::Or:
-    return Operator::Or;
-  case common::LogicalOperator::Eqv:
-    return Operator::Eqv;
-  case common::LogicalOperator::Neqv:
-    return Operator::Neqv;
-  case common::LogicalOperator::Not:
-    return Operator::Not;
-  }
-  return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
-  switch (op.derived().opr) {
-  case common::RelationalOperator::LT:
-    return Operator::Lt;
-  case common::RelationalOperator::LE:
-    return Operator::Le;
-  case common::RelationalOperator::EQ:
-    return Operator::Eq;
-  case common::RelationalOperator::NE:
-    return Operator::Ne;
-  case common::RelationalOperator::GE:
-    return Operator::Ge;
-  case common::RelationalOperator::GT:
-    return Operator::Gt;
-  }
-  return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
-  return Operator::Add;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
-  return Operator::Sub;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
-  return Operator::Mul;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
-  return Operator::Div;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
-  return Operator::Pow;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
-  return Operator::Pow;
-}
-
-template <typename T, common::TypeCategory C, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
-  if constexpr (C == T::category) {
-    return Operator::Resize;
-  } else {
-    return Operator::Convert;
-  }
-}
-
-template <typename T> //
-Operator OperationCode(const evaluate::Constant<T> &x) {
-  return Operator::Constant;
-}
-
-template <typename T> //
-Operator OperationCode(const T &) {
-  return Operator::Unknown;
-}
-
-Operator OperationCode(const evaluate::ProcedureDesignator &proc);
-
-} // namespace operation
-
-/// Return information about the top-level operation (ignoring parentheses):
-/// the operation code and the list of arguments.
-std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
-    const SomeExpr &expr);
-
-/// Check if expr is same as x, or a sequence of Convert operations on x.
-bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x);
-
-/// Strip away any top-level Convert operations (if any exist) and return
-/// the input value. A ComplexConstructor(x, 0) is also considered as a
-/// convert operation.
-/// If the input is not Operation, Designator, FunctionRef or Constant,
-/// it returns std::nullopt.
-MaybeExpr GetConvertInput(const SomeExpr &x);
 } // namespace Fortran::semantics
 #endif // FORTRAN_SEMANTICS_TOOLS_H_
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 222c32a9c332e..68838564f87ba 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -13,6 +13,7 @@
 #include "flang/Evaluate/traverse.h"
 #include "flang/Parser/message.h"
 #include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSwitch.h"
 #include <algorithm>
 #include <variant>
 
@@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages,
   }
 }
 
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs) {
+  if (!lhs || !rhs) {
+    return false;
+  }
+  SymbolVector lhsSymbols{GetSymbolVector(*lhs)};
+  if (lhsSymbols.empty()) {
+    return false;
+  }
+  const Symbol &first{*lhsSymbols.front()};
+  SymbolVector rhsSymbols{GetSymbolVector(*rhs)};
+  return std::any_of(rhsSymbols.begin(), rhsSymbols.end(),
+      [&](const Symbol &symbol) { return first == symbol; });
+}
+
+namespace operation {
+template <typename T> Expr<SomeType> AsSomeExpr(const T &x) {
+  return AsGenericExpr(common::Clone(x));
+}
+
+// Reports the code of an expression's top-level operation and its operands,
+// optionally looking through same-category (resizing) conversions.
+template <bool IgnoreResizingConverts>
+struct ArgumentExtractor
+    : public Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+          std::pair<operation::Operator, std::vector<Expr<SomeType>>>, false> {
+  using Arguments = std::vector<Expr<SomeType>>;
+  using Result = std::pair<operation::Operator, Arguments>;
+  using Base =
+      Traverse<ArgumentExtractor<IgnoreResizingConverts>, Result, false>;
+  static constexpr auto IgnoreResizes{IgnoreResizingConverts};
+  static constexpr auto Logical{common::TypeCategory::Logical};
+  ArgumentExtractor() : Base(*this) {}
+
+  Result Default() const { return {}; }
+
+  using Base::operator();
+
+  template <int Kind>
+  Result operator()(const Constant<Type<Logical, Kind>> &x) const {
+    if (const auto &val{x.GetScalarValue()}) {
+      return val->IsTrue()
+          ? std::make_pair(operation::Operator::True, Arguments{})
+          : std::make_pair(operation::Operator::False, Arguments{});
+    }
+    return Default();
+  }
+
+  template <typename R> Result operator()(const FunctionRef<R> &x) const {
+    Result result{operation::OperationCode(x.proc()), {}};
+    for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
+      if (const auto *arg{x.UnwrapArgExpr(i)}) {
+        result.second.push_back(*arg);
+      }
+    }
+    return result;
+  }
+
+  template <typename D, typename R, typename... Os>
+  Result operator()(const Operation<D, R, Os...> &x) const {
+    if constexpr (std::is_same_v<D, Parentheses<R>>) {
+      // Ignore top-level parentheses.
+      return (*this)(x.template operand<0>());
+    } else if constexpr (IgnoreResizes &&
+        std::is_same_v<D, Convert<R, R::category>>) {
+      // Ignore conversions within the same category.
+      // Atomic operations on int(kind=1) may be implicitly widened
+      // to int(kind=4) for example.
+      return (*this)(x.template operand<0>());
+    } else {
+      return std::make_pair(operation::OperationCode(x),
+          OperationArgs(x, std::index_sequence_for<Os...>{}));
+    }
+  }
+
+  template <typename T> Result operator()(const Designator<T> &x) const {
+    return {operation::Operator::Identity, {AsSomeExpr(x)}};
+  }
+
+  template <typename T> Result operator()(const Constant<T> &x) const {
+    return {operation::Operator::Identity, {AsSomeExpr(x)}};
+  }
+
+  template <typename... Rs>
+  Result Combine(Result &&result, Rs &&...results) const {
+    // There shouldn't be any combining needed, since we're stopping the
+    // traversal at the top-level operation, but implement one that picks
+    // the first non-empty result.
+    if constexpr (sizeof...(Rs) == 0) {
+      return std::move(result);
+    } else if (!result.second.empty()) {
+      return std::move(result);
+    } else {
+      return Combine(std::move(results)...);
+    }
+  }
+
+private:
+  template <typename D, typename R, typename... Os, size_t... Is>
+  Arguments OperationArgs(
+      const Operation<D, R, Os...> &x, std::index_sequence<Is...>) const {
+    return Arguments{Expr<SomeType>(x.template operand<Is>())...};
+  }
+};
+} // namespace operation
+
+std::string operation::ToString(operation::Operator op) {
+  switch (op) {
+  case Operator::Unknown:
+    return "??";
+  case Operator::Add:
+    return "+";
+  case Operator::And:
+    return "AND";
+  case Operator::Associated:
+    return "ASSOCIATED";
+  case Operator::Call:
+    return "function-call";
+  case Operator::Constant:
+    return "constant";
+  case Operator::Convert:
+    return "type-conversion";
+  case Operator::Div:
+    return "/";
+  case Operator::Eq:
+    return "==";
+  case Operator::Eqv:
+    return "EQV";
+  case Operator::False:
+    return ".FALSE.";
+  case Operator::Ge:
+    return ">=";
+  case Operator::Gt:
+    return ">";
+  case Operator::Identity:
+    return "identity";
+  case Operator::Intrinsic:
+    return "intrinsic";
+  case Operator::Le:
+    return "<=";
+  case Operator::Lt:
+    return "<";
+  case Operator::Max:
+    return "MAX";
+  case Operator::Min:
+    return "MIN";
+  case Operator::Mul:
+    return "*";
+  case Operator::Ne:
+    return "/=";
+  case Operator::Neqv:
+    return "NEQV/EOR";
+  case Operator::Not:
+    return "NOT";
+  case Operator::Or:
+    return "OR";
+  case Operator::Pow:
+    return "**";
+  case Operator::Resize:
+    return "resize";
+  case Operator::Sub:
+    return "-";
+  case Operator::True:
+    return ".TRUE.";
+  }
+  llvm_unreachable("Unhandled operator"); // every enumerator returns above
+}
+
+operation::Operator operation::OperationCode(const ProcedureDesignator &proc) {
+  // ASSOCIATED, MIN, MAX, IAND, IOR and IEOR get dedicated codes by name; any
+  // other procedure is a Call, or an Intrinsic if it is a specific intrinsic.
+  Operator byName{llvm::StringSwitch<Operator>(proc.GetName())
+          .Case("associated", Operator::Associated)
+          .Case("min", Operator::Min)
+          .Case("max", Operator::Max)
+          .Case("iand", Operator::And)
+          .Case("ior", Operator::Or)
+          .Case("ieor", Operator::Neqv)
+          .Default(Operator::Call)};
+  bool isIntrinsic{byName == Operator::Call && proc.GetSpecificIntrinsic()};
+  return isIntrinsic ? Operator::Intrinsic : byName;
+}
+
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr) {
+  return operation::ArgumentExtractor</*IgnoreResizingConverts=*/true>{}(expr);
+}
+
+namespace operation {
+struct ConvertCollector // Finds the operand under top-level conversions.
+    : public Traverse<ConvertCollector,
+          std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>,
+          false> {
+  using Result = // {operand under the conversions, their result types}
+      std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>;
+  using Base = Traverse<ConvertCollector, Result, false>;
+  ConvertCollector() : Base(*this) {}
+
+  Result Default() const { return {}; }
+
+  using Base::operator();
+
+  template <typename T> Result operator()(const Designator<T> &x) const {
+    return {AsSomeExpr(x), {}};
+  }
+
+  template <typename T> Result operator()(const FunctionRef<T> &x) const {
+    return {AsSomeExpr(x), {}};
+  }
+
+  template <typename T> Result operator()(const Constant<T> &x) const {
+    return {AsSomeExpr(x), {}};
+  }
+
+  template <typename D, typename R, typename... Os>
+  Result operator()(const Operation<D, R, Os...> &x) const {
+    if constexpr (std::is_same_v<D, Parentheses<R>>) {
+      // Ignore parentheses.
+      return (*this)(x.template operand<0>());
+    } else if constexpr (is_convert_v<D>) {
+      // Convert should always have a typed result, so it should be safe to
+      // dereference x.GetType().
+      return Combine(
+          {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+    } else if constexpr (is_complex_constructor_v<D>) {
+      // This is a conversion iff the imaginary operand is 0.
+      if (IsZero(x.template operand<1>())) {
+        return Combine(
+            {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+      } else {
+        return {AsSomeExpr(x.derived()), {}};
+      }
+    } else {
+      return {AsSomeExpr(x.derived()), {}};
+    }
+  }
+
+  template <typename... Rs>
+  Result Combine(Result &&result, Rs &&...results) const {
+    Result v(std::move(result)); // keep the sole operand, concatenate types
+    auto setValue{[](std::optional<Expr<SomeType>> &x,
+                      std::optional<Expr<SomeType>> &&y) {
+      assert((!x.has_value() || !y.has_value()) && "Multiple designators");
+      if (!x.has_value()) {
+        x = std::move(y);
+      }
+    }};
+    auto moveAppend{[](auto &accum, auto &&other) {
+      for (auto &&s : other) {
+        accum.push_back(std::move(s));
+      }
+    }};
+    (setValue(v.first, std::move(results).first), ...);
+    (moveAppend(v.second, std::move(results).second), ...);
+    return v;
+  }
+
+private:
+  template <typename A> static bool IsZero(const A &x) { return false; }
+  template <typename T> static bool IsZero(const Expr<T> &x) {
+    return common::visit([](auto &&s) { return IsZero(s); }, x.u);
+  }
+  template <typename T> static bool IsZero(const Constant<T> &x) {
+    if (auto &&maybeScalar{x.GetScalarValue()}) {
+      return maybeScalar->IsZero();
+    } else {
+      return false;
+    }
+  }
+
+  template <typename T> struct is_convert {
+    static constexpr bool value{false};
+  };
+  template <typename T, common::TypeCategory C>
+  struct is_convert<Convert<T, C>> {
+    static constexpr bool value{true};
+  };
+  template <int K> struct is_convert<ComplexComponent<K>> {
+    // Complex to real. NOTE(review): does this also match the imaginary part?
+    static constexpr bool value{true};
+  };
+  template <typename T>
+  static constexpr bool is_convert_v{is_convert<T>::value};
+
+  template <typename T> struct is_complex_constructor {
+    static constexpr bool value{false};
+  };
+  template <int K> struct is_complex_constructor<ComplexConstructor<K>> {
+    static constexpr bool value{true};
+  };
+  template <typename T>
+  static constexpr bool is_complex_constructor_v{
+      is_complex_constructor<T>::value};
+};
+} // namespace operation
+
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x) {
+  // A designator, function reference or constant comes back as a copy of x.
+  return std::get<0>(operation::ConvertCollector{}(x));
+}
+
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x) {
+  // True when expr is x itself, or x under one or more top-level conversions.
+  if (expr == x) {
+    return true;
+  }
+  if (std::optional<Expr<SomeType>> input{GetConvertInput(expr)}) {
+    return *input == x;
+  }
+  return false;
+}
 } // namespace Fortran::evaluate
 
 namespace Fortran::semantics {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 69e9c53baa740..3ef3330cba2d6 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Block &block = atomicCaptureOp->getRegion(0).back();
   firOpBuilder.setInsertionPointToStart(&block);
   if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
-    if (Fortran::semantics::CheckForSymbolMatch(
+    if (Fortran::evaluate::CheckForSymbolMatch(
             Fortran::semantics::GetExpr(stmt2Var),
             Fortran::semantics::GetExpr(stmt2Expr))) {
       // Atomic capture construct is of the form [capture-stmt, update-stmt]
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 82673f0948a5b..0acfd5b0a2534 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2840,11 +2840,12 @@ genAtomicUpdate(lower::AbstractConverter &converter, mlir::Location loc,
   mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
 
   // This must exist by now.
-  SomeExpr input = *semantics::GetConvertInput(assign.rhs);
-  std::vector<SomeExpr> args{semantics::GetTopLevelOperation(input).second};
+  SomeExpr input = *Fortran::evaluate::GetConvertInput(assign.rhs);
+  std::vector<SomeExpr> args{
+      Fortran::evaluate::GetTopLevelOperation(input).second};
   assert(!args.empty() && "Update operation without arguments");
   for (auto &arg : args) {
-    if (!semantics::IsSameOrConvertOf(arg, atom)) {
+    if (!Fortran::evaluate::IsSameOrConvertOf(arg, atom)) {
       mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
       overrides.try_emplace(&arg, val);
     }
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 58d28dce7094a..47bd4e8ffd43c 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -12,6 +12,7 @@
 #include "flang/Evaluate/check-expression.h"
 #include "flang/Evaluate/expression.h"
 #include "flang/Evaluate/shape.h"
+#include "flang/Evaluate/tools.h"
 #include "flang/Evaluate/type.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
@@ -2962,6 +2963,8 @@ static bool IsPointerAssignment(const evaluate::Assignment &x) {
       std::holds_alternative<evaluate::Assignment::BoundsRemapping>(x.u);
 }
 
+namespace operation = Fortran::evaluate::operation;
+
 static bool IsCheckForAssociated(const SomeExpr &cond) {
   return GetTopLevelOperation(cond).first == operation::Operator::Associated;
 }
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index bf520d04a50cc..d053179448c00 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -17,7 +17,6 @@
 #include "flang/Semantics/tools.h"
 #include "flang/Semantics/type.h"
 #include "flang/Support/Fortran.h"
-#include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <set>
@@ -1789,332 +1788,4 @@ bool HadUseError(
   }
 }
 
-bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs) {
-  if (lhs && rhs) {
-    if (SymbolVector lhsSymbols{evaluate::GetSymbolVector(*lhs)};
-        !lhsSymbols.empty()) {
-      const Symbol &first{*lhsSymbols.front()};
-      for (const Symbol &symbol : evaluate::GetSymbolVector(*rhs)) {
-        if (first == symbol) {
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
-namespace operation {
-template <typename T> //
-SomeExpr asSomeExpr(const T &x) {
-  auto copy{x};
-  return AsGenericExpr(std::move(copy));
-}
-
-template <bool IgnoreResizingConverts> //
-struct ArgumentExtractor
-    : public evaluate::Traverse<ArgumentExtractor<IgnoreResizingConverts>,
-          std::pair<operation::Operator, std::vector<SomeExpr>>, false> {
-  using Arguments = std::vector<SomeExpr>;
-  using Result = std::pair<operation::Operator, Arguments>;
-  using Base = evaluate::Traverse<ArgumentExtractor<IgnoreResizingConverts>,
-      Result, false>;
-  static constexpr auto IgnoreResizes = IgnoreResizingConverts;
-  static constexpr auto Logical = common::TypeCategory::Logical;
-  ArgumentExtractor() : Base(*this) {}
-
-  Result Default() const { return {}; }
-
-  using Base::operator();
-
-  template <int Kind> //
-  Result operator()(
-      const evaluate::Constant<evaluate::Type<Logical, Kind>> &x) const {
-    if (const auto &val{x.GetScalarValue()}) {
-      return val->IsTrue()
-          ? std::make_pair(operation::Operator::True, Arguments{})
-          : std::make_pair(operation::Operator::False, Arguments{});
-    }
-    return Default();
-  }
-
-  template <typename R> //
-  Result operator()(const evaluate::FunctionRef<R> &x) const {
-    Result result{operation::OperationCode(x.proc()), {}};
-    for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
-      if (auto *e{x.UnwrapArgExpr(i)}) {
-        result.second.push_back(*e);
-      }
-    }
-    return result;
-  }
-
-  template <typename D, typename R, typename... Os>
-  Result operator()(const evaluate::Operation<D, R, Os...> &x) const {
-    if constexpr (std::is_same_v<D, evaluate::Parentheses<R>>) {
-      // Ignore top-level parentheses.
-      return (*this)(x.template operand<0>());
-    }
-    if constexpr (IgnoreResizes &&
-        std::is_same_v<D, evaluate::Convert<R, R::category>>) {
-      // Ignore conversions within the same category.
-      // Atomic operations on int(kind=1) may be implicitly widened
-      // to int(kind=4) for example.
-      return (*this)(x.template operand<0>());
-    } else {
-      return std::make_pair(operation::OperationCode(x),
-          OperationArgs(x, std::index_sequence_for<Os...>{}));
-    }
-  }
-
-  template <typename T> //
-  Result operator()(const evaluate::Designator<T> &x) const {
-    return {operation::Operator::Identity, {asSomeExpr(x)}};
-  }
-
-  template <typename T> //
-  Result operator()(const evaluate::Constant<T> &x) const {
-    return {operation::Operator::Identity, {asSomeExpr(x)}};
-  }
-
-  template <typename... Rs> //
-  Result Combine(Result &&result, Rs &&...results) const {
-    // There shouldn't be any combining needed, since we're stopping the
-    // traversal at the top-level operation, but implement one that picks
-    // the first non-empty result.
-    if constexpr (sizeof...(Rs) == 0) {
-      return std::move(result);
-    } else {
-      if (!result.second.empty()) {
-        return std::move(result);
-      } else {
-        return Combine(std::move(results)...);
-      }
-    }
-  }
-
-private:
-  template <typename D, typename R, typename... Os, size_t... Is>
-  Arguments OperationArgs(const evaluate::Operation<D, R, Os...> &x,
-      std::index_sequence<Is...>) const {
-    return Arguments{SomeExpr(x.template operand<Is>())...};
-  }
-};
-} // namespace operation
-
-std::string operation::ToString(operation::Operator op) {
-  switch (op) {
-  case Operator::Unknown:
-    return "??";
-  case Operator::Add:
-    return "+";
-  case Operator::And:
-    return "AND";
-  case Operator::Associated:
-    return "ASSOCIATED";
-  case Operator::Call:
-    return "function-call";
-  case Operator::Constant:
-    return "constant";
-  case Operator::Convert:
-    return "type-conversion";
-  case Operator::Div:
-    return "/";
-  case Operator::Eq:
-    return "==";
-  case Operator::Eqv:
-    return "EQV";
-  case Operator::False:
-    return ".FALSE.";
-  case Operator::Ge:
-    return ">=";
-  case Operator::Gt:
-    return ">";
-  case Operator::Identity:
-    return "identity";
-  case Operator::Intrinsic:
-    return "intrinsic";
-  case Operator::Le:
-    return "<=";
-  case Operator::Lt:
-    return "<";
-  case Operator::Max:
-    return "MAX";
-  case Operator::Min:
-    return "MIN";
-  case Operator::Mul:
-    return "*";
-  case Operator::Ne:
-    return "/=";
-  case Operator::Neqv:
-    return "NEQV/EOR";
-  case Operator::Not:
-    return "NOT";
-  case Operator::Or:
-    return "OR";
-  case Operator::Pow:
-    return "**";
-  case Operator::Resize:
-    return "resize";
-  case Operator::Sub:
-    return "-";
-  case Operator::True:
-    return ".TRUE.";
-  }
-  llvm_unreachable("Unhandler operator");
-}
-
-operation::Operator operation::OperationCode(
-    const evaluate::ProcedureDesignator &proc) {
-  Operator code = llvm::StringSwitch<Operator>(proc.GetName())
-                      .Case("associated", Operator::Associated)
-                      .Case("min", Operator::Min)
-                      .Case("max", Operator::Max)
-                      .Case("iand", Operator::And)
-                      .Case("ior", Operator::Or)
-                      .Case("ieor", Operator::Neqv)
-                      .Default(Operator::Call);
-  if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
-    return Operator::Intrinsic;
-  }
-  return code;
-}
-
-std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
-    const SomeExpr &expr) {
-  return operation::ArgumentExtractor<true>{}(expr);
-}
-
-namespace operation {
-struct ConvertCollector
-    : public evaluate::Traverse<ConvertCollector,
-          std::pair<MaybeExpr, std::vector<evaluate::DynamicType>>, false> {
-  using Result = std::pair<MaybeExpr, std::vector<evaluate::DynamicType>>;
-  using Base = evaluate::Traverse<ConvertCollector, Result, false>;
-  ConvertCollector() : Base(*this) {}
-
-  Result Default() const { return {}; }
-
-  using Base::operator();
-
-  template <typename T> //
-  Result operator()(const evaluate::Designator<T> &x) const {
-    return {asSomeExpr(x), {}};
-  }
-
-  template <typename T> //
-  Result operator()(const evaluate::FunctionRef<T> &x) const {
-    return {asSomeExpr(x), {}};
-  }
-
-  template <typename T> //
-  Result operator()(const evaluate::Constant<T> &x) const {
-    return {asSomeExpr(x), {}};
-  }
-
-  template <typename D, typename R, typename... Os>
-  Result operator()(const evaluate::Operation<D, R, Os...> &x) const {
-    if constexpr (std::is_same_v<D, evaluate::Parentheses<R>>) {
-      // Ignore parentheses.
-      return (*this)(x.template operand<0>());
-    } else if constexpr (is_convert_v<D>) {
-      // Convert should always have a typed result, so it should be safe to
-      // dereference x.GetType().
-      return Combine(
-          {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
-    } else if constexpr (is_complex_constructor_v<D>) {
-      // This is a conversion iff the imaginary operand is 0.
-      if (IsZero(x.template operand<1>())) {
-        return Combine(
-            {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
-      } else {
-        return {asSomeExpr(x.derived()), {}};
-      }
-    } else {
-      return {asSomeExpr(x.derived()), {}};
-    }
-  }
-
-  template <typename... Rs> //
-  Result Combine(Result &&result, Rs &&...results) const {
-    Result v(std::move(result));
-    auto setValue{[](MaybeExpr &x, MaybeExpr &&y) {
-      assert((!x.has_value() || !y.has_value()) && "Multiple designators");
-      if (!x.has_value()) {
-        x = std::move(y);
-      }
-    }};
-    auto moveAppend{[](auto &accum, auto &&other) {
-      for (auto &&s : other) {
-        accum.push_back(std::move(s));
-      }
-    }};
-    (setValue(v.first, std::move(results).first), ...);
-    (moveAppend(v.second, std::move(results).second), ...);
-    return v;
-  }
-
-private:
-  template <typename T> //
-  static bool IsZero(const T &x) {
-    return false;
-  }
-  template <typename T> //
-  static bool IsZero(const evaluate::Expr<T> &x) {
-    return common::visit([](auto &&s) { return IsZero(s); }, x.u);
-  }
-  template <typename T> //
-  static bool IsZero(const evaluate::Constant<T> &x) {
-    if (auto &&maybeScalar{x.GetScalarValue()}) {
-      return maybeScalar->IsZero();
-    } else {
-      return false;
-    }
-  }
-
-  template <typename T> //
-  struct is_convert {
-    static constexpr bool value{false};
-  };
-  template <typename T, common::TypeCategory C> //
-  struct is_convert<evaluate::Convert<T, C>> {
-    static constexpr bool value{true};
-  };
-  template <int K> //
-  struct is_convert<evaluate::ComplexComponent<K>> {
-    // Conversion from complex to real.
-    static constexpr bool value{true};
-  };
-  template <typename T> //
-  static constexpr bool is_convert_v = is_convert<T>::value;
-
-  template <typename T> //
-  struct is_complex_constructor {
-    static constexpr bool value{false};
-  };
-  template <int K> //
-  struct is_complex_constructor<evaluate::ComplexConstructor<K>> {
-    static constexpr bool value{true};
-  };
-  template <typename T> //
-  static constexpr bool is_complex_constructor_v =
-      is_complex_constructor<T>::value;
-};
-} // namespace operation
-
-MaybeExpr GetConvertInput(const SomeExpr &x) {
-  // This returns SomeExpr(x) when x is a designator/functionref/constant.
-  return operation::ConvertCollector{}(x).first;
-}
-
-bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) {
-  // Check if expr is same as x, or a sequence of Convert operations on x.
-  if (expr == x) {
-    return true;
-  } else if (auto maybe{GetConvertInput(expr)}) {
-    return *maybe == x;
-  } else {
-    return false;
-  }
-}
 } // namespace Fortran::semantics
\ No newline at end of file



More information about the flang-commits mailing list