[flang-commits] [flang] [flang][NFC] Move new code to right place (PR #144551)
via flang-commits
flang-commits at lists.llvm.org
Tue Jun 17 08:53:28 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
Some new code was added to flang/Semantics that only depends on facilities in flang/Evaluate. Move it into Evaluate and clean up some minor stylistic problems.
---
Patch is 32.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144551.diff
7 Files Affected:
- (modified) flang/include/flang/Evaluate/tools.h (+148)
- (modified) flang/include/flang/Semantics/tools.h (-149)
- (modified) flang/lib/Evaluate/tools.cpp (+311)
- (modified) flang/lib/Lower/OpenACC.cpp (+1-1)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+4-3)
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+3)
- (modified) flang/lib/Semantics/tools.cpp (-329)
``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 1959d5f3a5899..e04621f71f9a7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
}
+// Checks whether the symbol on the LHS is present in the RHS expression.
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
+
+namespace operation {
+
+enum class Operator {
+ Unknown,
+ Add,
+ And,
+ Associated,
+ Call,
+ Constant,
+ Convert,
+ Div,
+ Eq,
+ Eqv,
+ False,
+ Ge,
+ Gt,
+ Identity,
+ Intrinsic,
+ Le,
+ Lt,
+ Max,
+ Min,
+ Mul,
+ Ne,
+ Neqv,
+ Not,
+ Or,
+ Pow,
+ Resize, // Convert within the same TypeCategory
+ Sub,
+ True,
+};
+
+std::string ToString(Operator op);
+
+template <typename... Ts, int Kind>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
+ switch (op.derived().logicalOperator) {
+ case common::LogicalOperator::And:
+ return Operator::And;
+ case common::LogicalOperator::Or:
+ return Operator::Or;
+ case common::LogicalOperator::Eqv:
+ return Operator::Eqv;
+ case common::LogicalOperator::Neqv:
+ return Operator::Neqv;
+ case common::LogicalOperator::Not:
+ return Operator::Not;
+ }
+ return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
+ switch (op.derived().opr) {
+ case common::RelationalOperator::LT:
+ return Operator::Lt;
+ case common::RelationalOperator::LE:
+ return Operator::Le;
+ case common::RelationalOperator::EQ:
+ return Operator::Eq;
+ case common::RelationalOperator::NE:
+ return Operator::Ne;
+ case common::RelationalOperator::GE:
+ return Operator::Ge;
+ case common::RelationalOperator::GT:
+ return Operator::Gt;
+ }
+ return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
+ return Operator::Add;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
+ return Operator::Sub;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
+ return Operator::Mul;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
+ return Operator::Div;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
+ return Operator::Pow;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
+ return Operator::Pow;
+}
+
+template <typename T, common::TypeCategory C, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
+ if constexpr (C == T::category) {
+ return Operator::Resize;
+ } else {
+ return Operator::Convert;
+ }
+}
+
+template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
+ return Operator::Constant;
+}
+
+template <typename T> Operator OperationCode(const T &) {
+ return Operator::Unknown;
+}
+
+Operator OperationCode(const evaluate::ProcedureDesignator &proc);
+
+} // namespace operation
+
+// Return information about the top-level operation (ignoring parentheses):
+// the operation code and the list of arguments.
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr);
+
+// Check if expr is same as x, or a sequence of Convert operations on x.
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x);
+
+// Strip away any top-level Convert operations (if any exist) and return
+// the input value. A ComplexConstructor(x, 0) is also considered as a
+// convert operation.
+// If the input is not Operation, Designator, FunctionRef or Constant,
+// it returns std::nullopt.
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x);
+
} // namespace Fortran::evaluate
namespace Fortran::semantics {
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index 69375a83dec25..f3cfa9b99fb4d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring);
// Check for ambiguous USE associations
bool HadUseError(SemanticsContext &, SourceName at, const Symbol *);
-// Checks whether the symbol on the LHS is present in the RHS expression.
-bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs);
-
-namespace operation {
-
-enum class Operator {
- Unknown,
- Add,
- And,
- Associated,
- Call,
- Constant,
- Convert,
- Div,
- Eq,
- Eqv,
- False,
- Ge,
- Gt,
- Identity,
- Intrinsic,
- Le,
- Lt,
- Max,
- Min,
- Mul,
- Ne,
- Neqv,
- Not,
- Or,
- Pow,
- Resize, // Convert within the same TypeCategory
- Sub,
- True,
-};
-
-std::string ToString(Operator op);
-
-template <typename... Ts, int Kind>
-Operator OperationCode(
- const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
- switch (op.derived().logicalOperator) {
- case common::LogicalOperator::And:
- return Operator::And;
- case common::LogicalOperator::Or:
- return Operator::Or;
- case common::LogicalOperator::Eqv:
- return Operator::Eqv;
- case common::LogicalOperator::Neqv:
- return Operator::Neqv;
- case common::LogicalOperator::Not:
- return Operator::Not;
- }
- return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
- switch (op.derived().opr) {
- case common::RelationalOperator::LT:
- return Operator::Lt;
- case common::RelationalOperator::LE:
- return Operator::Le;
- case common::RelationalOperator::EQ:
- return Operator::Eq;
- case common::RelationalOperator::NE:
- return Operator::Ne;
- case common::RelationalOperator::GE:
- return Operator::Ge;
- case common::RelationalOperator::GT:
- return Operator::Gt;
- }
- return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
- return Operator::Add;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
- return Operator::Sub;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
- return Operator::Mul;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
- return Operator::Div;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
- return Operator::Pow;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
- return Operator::Pow;
-}
-
-template <typename T, common::TypeCategory C, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
- if constexpr (C == T::category) {
- return Operator::Resize;
- } else {
- return Operator::Convert;
- }
-}
-
-template <typename T> //
-Operator OperationCode(const evaluate::Constant<T> &x) {
- return Operator::Constant;
-}
-
-template <typename T> //
-Operator OperationCode(const T &) {
- return Operator::Unknown;
-}
-
-Operator OperationCode(const evaluate::ProcedureDesignator &proc);
-
-} // namespace operation
-
-/// Return information about the top-level operation (ignoring parentheses):
-/// the operation code and the list of arguments.
-std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
- const SomeExpr &expr);
-
-/// Check if expr is same as x, or a sequence of Convert operations on x.
-bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x);
-
-/// Strip away any top-level Convert operations (if any exist) and return
-/// the input value. A ComplexConstructor(x, 0) is also considered as a
-/// convert operation.
-/// If the input is not Operation, Designator, FunctionRef or Constant,
-/// it returns std::nullopt.
-MaybeExpr GetConvertInput(const SomeExpr &x);
} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_TOOLS_H_
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 222c32a9c332e..68838564f87ba 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -13,6 +13,7 @@
#include "flang/Evaluate/traverse.h"
#include "flang/Parser/message.h"
#include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSwitch.h"
#include <algorithm>
#include <variant>
@@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages,
}
}
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs) {
+ if (lhs && rhs) {
+ if (SymbolVector lhsSymbols{GetSymbolVector(*lhs)}; !lhsSymbols.empty()) {
+ const Symbol &first{*lhsSymbols.front()};
+ for (const Symbol &symbol : GetSymbolVector(*rhs)) {
+ if (first == symbol) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+namespace operation {
+template <typename T> Expr<SomeType> AsSomeExpr(const T &x) {
+ return AsGenericExpr(common::Clone(x));
+}
+
+template <bool IgnoreResizingConverts>
+struct ArgumentExtractor
+ : public Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+ std::pair<operation::Operator, std::vector<Expr<SomeType>>>, false> {
+ using Arguments = std::vector<Expr<SomeType>>;
+ using Result = std::pair<operation::Operator, Arguments>;
+ using Base =
+ Traverse<ArgumentExtractor<IgnoreResizingConverts>, Result, false>;
+ static constexpr auto IgnoreResizes{IgnoreResizingConverts};
+ static constexpr auto Logical{common::TypeCategory::Logical};
+ ArgumentExtractor() : Base(*this) {}
+
+ Result Default() const { return {}; }
+
+ using Base::operator();
+
+ template <int Kind>
+ Result operator()(const Constant<Type<Logical, Kind>> &x) const {
+ if (const auto &val{x.GetScalarValue()}) {
+ return val->IsTrue()
+ ? std::make_pair(operation::Operator::True, Arguments{})
+ : std::make_pair(operation::Operator::False, Arguments{});
+ }
+ return Default();
+ }
+
+ template <typename R> Result operator()(const FunctionRef<R> &x) const {
+ Result result{operation::OperationCode(x.proc()), {}};
+ for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
+ if (auto *e{x.UnwrapArgExpr(i)}) {
+ result.second.push_back(*e);
+ }
+ }
+ return result;
+ }
+
+ template <typename D, typename R, typename... Os>
+ Result operator()(const Operation<D, R, Os...> &x) const {
+ if constexpr (std::is_same_v<D, Parentheses<R>>) {
+ // Ignore top-level parentheses.
+ return (*this)(x.template operand<0>());
+ }
+ if constexpr (IgnoreResizes && std::is_same_v<D, Convert<R, R::category>>) {
+ // Ignore conversions within the same category.
+ // Atomic operations on int(kind=1) may be implicitly widened
+ // to int(kind=4) for example.
+ return (*this)(x.template operand<0>());
+ } else {
+ return std::make_pair(operation::OperationCode(x),
+ OperationArgs(x, std::index_sequence_for<Os...>{}));
+ }
+ }
+
+ template <typename T> Result operator()(const Designator<T> &x) const {
+ return {operation::Operator::Identity, {AsSomeExpr(x)}};
+ }
+
+ template <typename T> Result operator()(const Constant<T> &x) const {
+ return {operation::Operator::Identity, {AsSomeExpr(x)}};
+ }
+
+ template <typename... Rs>
+ Result Combine(Result &&result, Rs &&...results) const {
+ // There shouldn't be any combining needed, since we're stopping the
+ // traversal at the top-level operation, but implement one that picks
+ // the first non-empty result.
+ if constexpr (sizeof...(Rs) == 0) {
+ return std::move(result);
+ } else {
+ if (!result.second.empty()) {
+ return std::move(result);
+ } else {
+ return Combine(std::move(results)...);
+ }
+ }
+ }
+
+private:
+ template <typename D, typename R, typename... Os, size_t... Is>
+ Arguments OperationArgs(
+ const Operation<D, R, Os...> &x, std::index_sequence<Is...>) const {
+ return Arguments{Expr<SomeType>(x.template operand<Is>())...};
+ }
+};
+} // namespace operation
+
+std::string operation::ToString(operation::Operator op) {
+ switch (op) {
+ case Operator::Unknown:
+ return "??";
+ case Operator::Add:
+ return "+";
+ case Operator::And:
+ return "AND";
+ case Operator::Associated:
+ return "ASSOCIATED";
+ case Operator::Call:
+ return "function-call";
+ case Operator::Constant:
+ return "constant";
+ case Operator::Convert:
+ return "type-conversion";
+ case Operator::Div:
+ return "/";
+ case Operator::Eq:
+ return "==";
+ case Operator::Eqv:
+ return "EQV";
+ case Operator::False:
+ return ".FALSE.";
+ case Operator::Ge:
+ return ">=";
+ case Operator::Gt:
+ return ">";
+ case Operator::Identity:
+ return "identity";
+ case Operator::Intrinsic:
+ return "intrinsic";
+ case Operator::Le:
+ return "<=";
+ case Operator::Lt:
+ return "<";
+ case Operator::Max:
+ return "MAX";
+ case Operator::Min:
+ return "MIN";
+ case Operator::Mul:
+ return "*";
+ case Operator::Ne:
+ return "/=";
+ case Operator::Neqv:
+ return "NEQV/EOR";
+ case Operator::Not:
+ return "NOT";
+ case Operator::Or:
+ return "OR";
+ case Operator::Pow:
+ return "**";
+ case Operator::Resize:
+ return "resize";
+ case Operator::Sub:
+ return "-";
+ case Operator::True:
+ return ".TRUE.";
+ }
+ llvm_unreachable("Unhandler operator");
+}
+
+operation::Operator operation::OperationCode(const ProcedureDesignator &proc) {
+ Operator code{llvm::StringSwitch<Operator>(proc.GetName())
+ .Case("associated", Operator::Associated)
+ .Case("min", Operator::Min)
+ .Case("max", Operator::Max)
+ .Case("iand", Operator::And)
+ .Case("ior", Operator::Or)
+ .Case("ieor", Operator::Neqv)
+ .Default(Operator::Call)};
+ if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
+ return Operator::Intrinsic;
+ }
+ return code;
+}
+
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr) {
+ return operation::ArgumentExtractor<true>{}(expr);
+}
+
+namespace operation {
+struct ConvertCollector
+ : public Traverse<ConvertCollector,
+ std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>,
+ false> {
+ using Result =
+ std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>;
+ using Base = Traverse<ConvertCollector, Result, false>;
+ ConvertCollector() : Base(*this) {}
+
+ Result Default() const { return {}; }
+
+ using Base::operator();
+
+ template <typename T> Result operator()(const Designator<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename T> Result operator()(const FunctionRef<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename T> Result operator()(const Constant<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename D, typename R, typename... Os>
+ Result operator()(const Operation<D, R, Os...> &x) const {
+ if constexpr (std::is_same_v<D, Parentheses<R>>) {
+ // Ignore parentheses.
+ return (*this)(x.template operand<0>());
+ } else if constexpr (is_convert_v<D>) {
+ // Convert should always have a typed result, so it should be safe to
+ // dereference x.GetType().
+ return Combine(
+ {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+ } else if constexpr (is_complex_constructor_v<D>) {
+ // This is a conversion iff the imaginary operand is 0.
+ if (IsZero(x.template operand<1>())) {
+ return Combine(
+ {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+ } else {
+ return {AsSomeExpr(x.derived()), {}};
+ }
+ } else {
+ return {AsSomeExpr(x.derived()), {}};
+ }
+ }
+
+ template <typename... Rs>
+ Result Combine(Result &&result, Rs &&...results) const {
+ Result v(std::move(result));
+ auto setValue{[](std::optional<Expr<SomeType>> &x,
+ std::optional<Expr<SomeType>> &&y) {
+ assert((!x.has_value() || !y.has_value()) && "Multiple designators");
+ if (!x.has_value()) {
+ x = std::move(y);
+ }
+ }};
+ auto moveAppend{[](auto &accum, auto &&other) {
+ for (auto &&s : other) {
+ accum.push_back(std::move(s));
+ }
+ }};
+ (setValue(v.first, std::move(results).first), ...);
+ (moveAppend(v.second, std::move(results).second), ...);
+ return v;
+ }
+
+private:
+ template <typename A> static bool IsZero(const A &x) { return false; }
+ template <typename T> static bool IsZero(const Expr<T> &x) {
+ return common::visit([](auto &&s) { return IsZero(s); }, x.u);
+ }
+ template <typename T> static bool IsZero(const Constant<T> &x) {
+ if (auto &&maybeScalar{x.GetScalarValue()}) {
+ return maybeScalar->IsZero();
+ } else {
+ return false;
+ }
+ }
+
+ template <typename T> struct is_convert {
+ static constexpr bool value{false};
+ };
+ template <typename T, common::TypeCategory C>
+ struct is_convert<Convert<T, C>> {
+ static constexpr bool value{true};
+ };
+ template <int K> struct is_convert<ComplexComponent<K>> {
+ // Conversion from complex to real.
+ static constexpr bool value{true};
+ };
+ template <typename T>
+ static constexpr bool is_convert_v{is_convert<T>::value};
+
+ template <typename T> struct is_complex_constructor {
+ static constexpr bool value{false};
+ };
+ template <int K> struct is_complex_constructor<ComplexConstructor<K>> {
+ static constexpr bool value{true};
+ };
+ template <typename T>
+ static constexpr bool is_complex_constructor_v{
+ is_complex_constructor<T>::value};
+};
+} // namespace operation
+
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x) {
+ // This returns Expr<SomeType>{x} when x is a designator/functionref/constant.
+ return operation::ConvertCollector{}(x).first;
+}
+
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x) {
+ // Check if expr is same as x, or a sequence of Convert operations on x.
+ if (expr == x) {
+ return true;
+ } else if (auto maybe{GetConvertInput(expr)}) {
+ return *maybe == x;
+ } else {
+ return false;
+ }
+}
} // namespace Fortran::evaluate
namespace Fortran::semantics {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 69e9c53baa740..3ef3330cba2d6 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
firOpBuilder.setInsertionPointToStart(&block);
if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
- if (Fortran::semantics::CheckForSymbolMatch(
+ if (Fortran::evaluate::CheckForSymbolMatch(
Fortran::semantics::GetExpr(stmt2Var),
Fortran::semantics::GetExpr(stmt2Expr))) {
// Atomic capture construct is of the form [capture-stmt, update-stmt]
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 82673f0948a5b..0acfd5b0a2534 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2840,11 +2840,12 @@ genAtomicUpdate(lower::AbstractConverter &converter, mlir::Location loc,
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
// This must ex...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/144551
More information about the flang-commits
mailing list