[flang-commits] [flang] [flang][NFC] Move new code to right place (PR #144551)

via flang-commits flang-commits at lists.llvm.org
Tue Jun 17 08:53:28 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

Some recently added code in flang/Semantics depends only on facilities in flang/Evaluate. Move it into Evaluate and clean up some minor stylistic problems.

---

Patch is 32.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144551.diff


7 Files Affected:

- (modified) flang/include/flang/Evaluate/tools.h (+148) 
- (modified) flang/include/flang/Semantics/tools.h (-149) 
- (modified) flang/lib/Evaluate/tools.cpp (+311) 
- (modified) flang/lib/Lower/OpenACC.cpp (+1-1) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+4-3) 
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+3) 
- (modified) flang/lib/Semantics/tools.cpp (-329) 


``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 1959d5f3a5899..e04621f71f9a7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
 }
 
+// Checks whether the symbol on the LHS is present in the RHS expression.
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
+
+namespace operation {
+
+enum class Operator {
+  Unknown,
+  Add,
+  And,
+  Associated,
+  Call,
+  Constant,
+  Convert,
+  Div,
+  Eq,
+  Eqv,
+  False,
+  Ge,
+  Gt,
+  Identity,
+  Intrinsic,
+  Le,
+  Lt,
+  Max,
+  Min,
+  Mul,
+  Ne,
+  Neqv,
+  Not,
+  Or,
+  Pow,
+  Resize, // Convert within the same TypeCategory
+  Sub,
+  True,
+};
+
+std::string ToString(Operator op);
+
+template <typename... Ts, int Kind>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
+  switch (op.derived().logicalOperator) {
+  case common::LogicalOperator::And:
+    return Operator::And;
+  case common::LogicalOperator::Or:
+    return Operator::Or;
+  case common::LogicalOperator::Eqv:
+    return Operator::Eqv;
+  case common::LogicalOperator::Neqv:
+    return Operator::Neqv;
+  case common::LogicalOperator::Not:
+    return Operator::Not;
+  }
+  return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
+  switch (op.derived().opr) {
+  case common::RelationalOperator::LT:
+    return Operator::Lt;
+  case common::RelationalOperator::LE:
+    return Operator::Le;
+  case common::RelationalOperator::EQ:
+    return Operator::Eq;
+  case common::RelationalOperator::NE:
+    return Operator::Ne;
+  case common::RelationalOperator::GE:
+    return Operator::Ge;
+  case common::RelationalOperator::GT:
+    return Operator::Gt;
+  }
+  return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
+  return Operator::Add;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
+  return Operator::Sub;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
+  return Operator::Mul;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
+  return Operator::Div;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
+  return Operator::Pow;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
+  return Operator::Pow;
+}
+
+template <typename T, common::TypeCategory C, typename... Ts>
+Operator OperationCode(
+    const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
+  if constexpr (C == T::category) {
+    return Operator::Resize;
+  } else {
+    return Operator::Convert;
+  }
+}
+
+template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
+  return Operator::Constant;
+}
+
+template <typename T> Operator OperationCode(const T &) {
+  return Operator::Unknown;
+}
+
+Operator OperationCode(const evaluate::ProcedureDesignator &proc);
+
+} // namespace operation
+
+// Return information about the top-level operation (ignoring parentheses):
+// the operation code and the list of arguments.
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr);
+
+// Check if expr is same as x, or a sequence of Convert operations on x.
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x);
+
+// Strip away any top-level Convert operations (if any exist) and return
+// the input value. A ComplexConstructor(x, 0) is also considered as a
+// convert operation.
+// If the input is not Operation, Designator, FunctionRef or Constant,
+// it returns std::nullopt.
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x);
+
 } // namespace Fortran::evaluate
 
 namespace Fortran::semantics {
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index 69375a83dec25..f3cfa9b99fb4d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring);
 // Check for ambiguous USE associations
 bool HadUseError(SemanticsContext &, SourceName at, const Symbol *);
 
-// Checks whether the symbol on the LHS is present in the RHS expression.
-bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs);
-
-namespace operation {
-
-enum class Operator {
-  Unknown,
-  Add,
-  And,
-  Associated,
-  Call,
-  Constant,
-  Convert,
-  Div,
-  Eq,
-  Eqv,
-  False,
-  Ge,
-  Gt,
-  Identity,
-  Intrinsic,
-  Le,
-  Lt,
-  Max,
-  Min,
-  Mul,
-  Ne,
-  Neqv,
-  Not,
-  Or,
-  Pow,
-  Resize, // Convert within the same TypeCategory
-  Sub,
-  True,
-};
-
-std::string ToString(Operator op);
-
-template <typename... Ts, int Kind>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
-  switch (op.derived().logicalOperator) {
-  case common::LogicalOperator::And:
-    return Operator::And;
-  case common::LogicalOperator::Or:
-    return Operator::Or;
-  case common::LogicalOperator::Eqv:
-    return Operator::Eqv;
-  case common::LogicalOperator::Neqv:
-    return Operator::Neqv;
-  case common::LogicalOperator::Not:
-    return Operator::Not;
-  }
-  return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
-  switch (op.derived().opr) {
-  case common::RelationalOperator::LT:
-    return Operator::Lt;
-  case common::RelationalOperator::LE:
-    return Operator::Le;
-  case common::RelationalOperator::EQ:
-    return Operator::Eq;
-  case common::RelationalOperator::NE:
-    return Operator::Ne;
-  case common::RelationalOperator::GE:
-    return Operator::Ge;
-  case common::RelationalOperator::GT:
-    return Operator::Gt;
-  }
-  return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
-  return Operator::Add;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
-  return Operator::Sub;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
-  return Operator::Mul;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
-  return Operator::Div;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
-  return Operator::Pow;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
-  return Operator::Pow;
-}
-
-template <typename T, common::TypeCategory C, typename... Ts>
-Operator OperationCode(
-    const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
-  if constexpr (C == T::category) {
-    return Operator::Resize;
-  } else {
-    return Operator::Convert;
-  }
-}
-
-template <typename T> //
-Operator OperationCode(const evaluate::Constant<T> &x) {
-  return Operator::Constant;
-}
-
-template <typename T> //
-Operator OperationCode(const T &) {
-  return Operator::Unknown;
-}
-
-Operator OperationCode(const evaluate::ProcedureDesignator &proc);
-
-} // namespace operation
-
-/// Return information about the top-level operation (ignoring parentheses):
-/// the operation code and the list of arguments.
-std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
-    const SomeExpr &expr);
-
-/// Check if expr is same as x, or a sequence of Convert operations on x.
-bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x);
-
-/// Strip away any top-level Convert operations (if any exist) and return
-/// the input value. A ComplexConstructor(x, 0) is also considered as a
-/// convert operation.
-/// If the input is not Operation, Designator, FunctionRef or Constant,
-/// it returns std::nullopt.
-MaybeExpr GetConvertInput(const SomeExpr &x);
 } // namespace Fortran::semantics
 #endif // FORTRAN_SEMANTICS_TOOLS_H_
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 222c32a9c332e..68838564f87ba 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -13,6 +13,7 @@
 #include "flang/Evaluate/traverse.h"
 #include "flang/Parser/message.h"
 #include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSwitch.h"
 #include <algorithm>
 #include <variant>
 
@@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages,
   }
 }
 
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs) {
+  if (lhs && rhs) {
+    if (SymbolVector lhsSymbols{GetSymbolVector(*lhs)}; !lhsSymbols.empty()) {
+      const Symbol &first{*lhsSymbols.front()};
+      for (const Symbol &symbol : GetSymbolVector(*rhs)) {
+        if (first == symbol) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+namespace operation {
+template <typename T> Expr<SomeType> AsSomeExpr(const T &x) {
+  return AsGenericExpr(common::Clone(x));
+}
+
+template <bool IgnoreResizingConverts>
+struct ArgumentExtractor
+    : public Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+          std::pair<operation::Operator, std::vector<Expr<SomeType>>>, false> {
+  using Arguments = std::vector<Expr<SomeType>>;
+  using Result = std::pair<operation::Operator, Arguments>;
+  using Base =
+      Traverse<ArgumentExtractor<IgnoreResizingConverts>, Result, false>;
+  static constexpr auto IgnoreResizes{IgnoreResizingConverts};
+  static constexpr auto Logical{common::TypeCategory::Logical};
+  ArgumentExtractor() : Base(*this) {}
+
+  Result Default() const { return {}; }
+
+  using Base::operator();
+
+  template <int Kind>
+  Result operator()(const Constant<Type<Logical, Kind>> &x) const {
+    if (const auto &val{x.GetScalarValue()}) {
+      return val->IsTrue()
+          ? std::make_pair(operation::Operator::True, Arguments{})
+          : std::make_pair(operation::Operator::False, Arguments{});
+    }
+    return Default();
+  }
+
+  template <typename R> Result operator()(const FunctionRef<R> &x) const {
+    Result result{operation::OperationCode(x.proc()), {}};
+    for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
+      if (auto *argExpr{x.UnwrapArgExpr(i)}) { // null => not collected
+        result.second.push_back(*argExpr);
+      }
+    }
+    return result;
+  }
+
+  template <typename D, typename R, typename... Os>
+  Result operator()(const Operation<D, R, Os...> &x) const {
+    if constexpr (std::is_same_v<D, Parentheses<R>>) {
+      // Ignore top-level parentheses.
+      return (*this)(x.template operand<0>());
+    }
+    if constexpr (IgnoreResizes && std::is_same_v<D, Convert<R, R::category>>) {
+      // Ignore conversions within the same category.
+      // Atomic operations on int(kind=1) may be implicitly widened
+      // to int(kind=4) for example.
+      return (*this)(x.template operand<0>());
+    } else {
+      return std::make_pair(operation::OperationCode(x),
+          OperationArgs(x, std::index_sequence_for<Os...>{}));
+    }
+  }
+
+  template <typename T> Result operator()(const Designator<T> &x) const {
+    return {operation::Operator::Identity, {AsSomeExpr(x)}};
+  }
+
+  template <typename T> Result operator()(const Constant<T> &x) const {
+    return {operation::Operator::Identity, {AsSomeExpr(x)}};
+  }
+
+  template <typename... Rs>
+  Result Combine(Result &&result, Rs &&...results) const {
+    // There shouldn't be any combining needed, since we're stopping the
+    // traversal at the top-level operation, but implement one that picks
+    // the first non-empty result.
+    if constexpr (sizeof...(Rs) == 0) {
+      return std::move(result);
+    } else {
+      if (!result.second.empty()) {
+        return std::move(result);
+      } else {
+        return Combine(std::move(results)...);
+      }
+    }
+  }
+
+private:
+  template <typename D, typename R, typename... Os, size_t... Is>
+  Arguments OperationArgs(
+      const Operation<D, R, Os...> &x, std::index_sequence<Is...>) const {
+    return Arguments{Expr<SomeType>(x.template operand<Is>())...};
+  }
+};
+} // namespace operation
+
+std::string operation::ToString(operation::Operator op) {
+  switch (op) {
+  case Operator::Unknown:
+    return "??";
+  case Operator::Add:
+    return "+";
+  case Operator::And:
+    return "AND";
+  case Operator::Associated:
+    return "ASSOCIATED";
+  case Operator::Call:
+    return "function-call";
+  case Operator::Constant:
+    return "constant";
+  case Operator::Convert:
+    return "type-conversion";
+  case Operator::Div:
+    return "/";
+  case Operator::Eq:
+    return "==";
+  case Operator::Eqv:
+    return "EQV";
+  case Operator::False:
+    return ".FALSE.";
+  case Operator::Ge:
+    return ">=";
+  case Operator::Gt:
+    return ">";
+  case Operator::Identity:
+    return "identity";
+  case Operator::Intrinsic:
+    return "intrinsic";
+  case Operator::Le:
+    return "<=";
+  case Operator::Lt:
+    return "<";
+  case Operator::Max:
+    return "MAX";
+  case Operator::Min:
+    return "MIN";
+  case Operator::Mul:
+    return "*";
+  case Operator::Ne:
+    return "/=";
+  case Operator::Neqv:
+    return "NEQV/EOR";
+  case Operator::Not:
+    return "NOT";
+  case Operator::Or:
+    return "OR";
+  case Operator::Pow:
+    return "**";
+  case Operator::Resize:
+    return "resize";
+  case Operator::Sub:
+    return "-";
+  case Operator::True:
+    return ".TRUE.";
+  }
+  llvm_unreachable("Unhandled operator");
+}
+
+operation::Operator operation::OperationCode(const ProcedureDesignator &proc) {
+  Operator code{llvm::StringSwitch<Operator>(proc.GetName())
+          .Case("associated", Operator::Associated)
+          .Case("min", Operator::Min)
+          .Case("max", Operator::Max)
+          .Case("iand", Operator::And)
+          .Case("ior", Operator::Or)
+          .Case("ieor", Operator::Neqv)
+          .Default(Operator::Call)};
+  if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
+    return Operator::Intrinsic;
+  }
+  return code;
+}
+
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr) {
+  return operation::ArgumentExtractor<true>{}(expr);
+}
+
+namespace operation {
+struct ConvertCollector
+    : public Traverse<ConvertCollector,
+          std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>,
+          false> {
+  using Result =
+      std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>;
+  using Base = Traverse<ConvertCollector, Result, false>;
+  ConvertCollector() : Base(*this) {}
+
+  Result Default() const { return {}; }
+
+  using Base::operator();
+
+  template <typename T> Result operator()(const Designator<T> &x) const {
+    return {AsSomeExpr(x), {}};
+  }
+
+  template <typename T> Result operator()(const FunctionRef<T> &x) const {
+    return {AsSomeExpr(x), {}};
+  }
+
+  template <typename T> Result operator()(const Constant<T> &x) const {
+    return {AsSomeExpr(x), {}};
+  }
+
+  template <typename D, typename R, typename... Os>
+  Result operator()(const Operation<D, R, Os...> &x) const {
+    if constexpr (std::is_same_v<D, Parentheses<R>>) {
+      // Ignore parentheses.
+      return (*this)(x.template operand<0>());
+    } else if constexpr (is_convert_v<D>) {
+      // Convert should always have a typed result, so it should be safe to
+      // dereference x.GetType().
+      return Combine(
+          {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+    } else if constexpr (is_complex_constructor_v<D>) {
+      // This is a conversion iff the imaginary operand is 0.
+      if (IsZero(x.template operand<1>())) {
+        return Combine(
+            {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+      } else {
+        return {AsSomeExpr(x.derived()), {}};
+      }
+    } else {
+      return {AsSomeExpr(x.derived()), {}};
+    }
+  }
+
+  template <typename... Rs>
+  Result Combine(Result &&result, Rs &&...results) const {
+    Result v(std::move(result));
+    auto setValue{[](std::optional<Expr<SomeType>> &x,
+                      std::optional<Expr<SomeType>> &&y) {
+      assert((!x.has_value() || !y.has_value()) && "Multiple designators");
+      if (!x.has_value()) {
+        x = std::move(y);
+      }
+    }};
+    auto moveAppend{[](auto &accum, auto &&other) {
+      for (auto &&s : other) {
+        accum.push_back(std::move(s));
+      }
+    }};
+    (setValue(v.first, std::move(results).first), ...);
+    (moveAppend(v.second, std::move(results).second), ...);
+    return v;
+  }
+
+private:
+  template <typename A> static bool IsZero(const A &x) { return false; }
+  template <typename T> static bool IsZero(const Expr<T> &x) {
+    return common::visit([](auto &&s) { return IsZero(s); }, x.u);
+  }
+  template <typename T> static bool IsZero(const Constant<T> &x) {
+    if (auto &&maybeScalar{x.GetScalarValue()}) {
+      return maybeScalar->IsZero();
+    } else {
+      return false;
+    }
+  }
+
+  template <typename T> struct is_convert {
+    static constexpr bool value{false};
+  };
+  template <typename T, common::TypeCategory C>
+  struct is_convert<Convert<T, C>> {
+    static constexpr bool value{true};
+  };
+  template <int K> struct is_convert<ComplexComponent<K>> {
+    // Conversion from complex to real.
+    static constexpr bool value{true};
+  };
+  template <typename T>
+  static constexpr bool is_convert_v{is_convert<T>::value};
+
+  template <typename T> struct is_complex_constructor {
+    static constexpr bool value{false};
+  };
+  template <int K> struct is_complex_constructor<ComplexConstructor<K>> {
+    static constexpr bool value{true};
+  };
+  template <typename T>
+  static constexpr bool is_complex_constructor_v{
+      is_complex_constructor<T>::value};
+};
+} // namespace operation
+
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x) {
+  // This returns Expr<SomeType>{x} when x is a designator/functionref/constant.
+  return operation::ConvertCollector{}(x).first;
+}
+
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x) {
+  // Check if expr is same as x, or a sequence of Convert operations on x.
+  if (expr == x) {
+    return true;
+  } else if (auto maybe{GetConvertInput(expr)}) {
+    return *maybe == x;
+  } else {
+    return false;
+  }
+}
 } // namespace Fortran::evaluate
 
 namespace Fortran::semantics {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 69e9c53baa740..3ef3330cba2d6 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Block &block = atomicCaptureOp->getRegion(0).back();
   firOpBuilder.setInsertionPointToStart(&block);
   if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
-    if (Fortran::semantics::CheckForSymbolMatch(
+    if (Fortran::evaluate::CheckForSymbolMatch(
             Fortran::semantics::GetExpr(stmt2Var),
             Fortran::semantics::GetExpr(stmt2Expr))) {
       // Atomic capture construct is of the form [capture-stmt, update-stmt]
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 82673f0948a5b..0acfd5b0a2534 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2840,11 +2840,12 @@ genAtomicUpdate(lower::AbstractConverter &converter, mlir::Location loc,
   mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
 
   // This must ex...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144551


More information about the flang-commits mailing list