[flang-commits] [flang] 3a26596 - [flang] Fold complex component references
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Wed May 11 10:04:21 PDT 2022
Author: Peter Klausler
Date: 2022-05-11T10:04:13-07:00
New Revision: 3a26596af3613a2ede294ed017f2c05e48255713
URL: https://github.com/llvm/llvm-project/commit/3a26596af3613a2ede294ed017f2c05e48255713
DIFF: https://github.com/llvm/llvm-project/commit/3a26596af3613a2ede294ed017f2c05e48255713.diff
LOG: [flang] Fold complex component references
Complex component references (z%RE, z%IM) of complex named constants
should be evaluated at compilation time.
Differential Revision: https://reviews.llvm.org/D125341
Added:
flang/test/Evaluate/fold-re-im.f90
Modified:
flang/include/flang/Evaluate/tools.h
flang/include/flang/Evaluate/type.h
flang/include/flang/Evaluate/variable.h
flang/lib/Evaluate/fold-implementation.h
flang/lib/Evaluate/fold-real.cpp
flang/lib/Evaluate/tools.cpp
flang/lib/Semantics/expression.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index ac75de8899bb..403625cc4cf0 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -149,6 +149,7 @@ Expr<SomeType> Parenthesize(Expr<SomeType> &&);
Expr<SomeReal> GetComplexPart(
const Expr<SomeComplex> &, bool isImaginary = false);
+Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false);
template <int KIND>
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h
index 08c9e94c9d89..b7270b2682ff 100644
--- a/flang/include/flang/Evaluate/type.h
+++ b/flang/include/flang/Evaluate/type.h
@@ -461,7 +461,6 @@ int SelectedRealKind(
#define EXPAND_FOR_EACH_CHARACTER_KIND(M, P, S) M(P, S, 1) M(P, S, 2) M(P, S, 4)
#define EXPAND_FOR_EACH_LOGICAL_KIND(M, P, S) \
M(P, S, 1) M(P, S, 2) M(P, S, 4) M(P, S, 8)
-#define TEMPLATE_INSTANTIATION(P, S, ARG) P<ARG> S;
#define FOR_EACH_INTEGER_KIND_HELP(PREFIX, SUFFIX, K) \
PREFIX<Type<TypeCategory::Integer, K>> SUFFIX;
diff --git a/flang/include/flang/Evaluate/variable.h b/flang/include/flang/Evaluate/variable.h
index 0a689473cfc9..dac51edf2f42 100644
--- a/flang/include/flang/Evaluate/variable.h
+++ b/flang/include/flang/Evaluate/variable.h
@@ -353,6 +353,7 @@ class ComplexPart {
ENUM_CLASS(Part, RE, IM)
CLASS_BOILERPLATE(ComplexPart)
ComplexPart(DataRef &&z, Part p) : complex_{std::move(z)}, part_{p} {}
+ DataRef &complex() { return complex_; }
const DataRef &complex() const { return complex_; }
Part part() const { return part_; }
int Rank() const;
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 317575ef9112..04295d31c619 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -58,6 +58,7 @@ template <typename T> class Folder {
std::optional<Constant<T>> GetConstantComponent(
Component &, const std::vector<Constant<SubscriptInteger>> * = nullptr);
std::optional<Constant<T>> Folding(ArrayRef &);
+ std::optional<Constant<T>> Folding(DataRef &);
Expr<T> Folding(Designator<T> &&);
Constant<T> *Folding(std::optional<ActualArgument> &);
@@ -118,27 +119,12 @@ CoarrayRef FoldOperation(FoldingContext &, CoarrayRef &&);
DataRef FoldOperation(FoldingContext &, DataRef &&);
Substring FoldOperation(FoldingContext &, Substring &&);
ComplexPart FoldOperation(FoldingContext &, ComplexPart &&);
-
template <typename T>
-Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&);
-template <int KIND>
-Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
- FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&);
-template <int KIND>
-Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
- FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&);
-template <int KIND>
-Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
- FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&);
-template <int KIND>
-Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
- FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&);
-
+Expr<T> FoldOperation(FoldingContext &, FunctionRef<T> &&);
template <typename T>
Expr<T> FoldOperation(FoldingContext &context, Designator<T> &&designator) {
return Folder<T>{context}.Folding(std::move(designator));
}
-
Expr<TypeParamInquiry::Result> FoldOperation(
FoldingContext &, TypeParamInquiry &&);
Expr<ImpliedDoIndex::Result> FoldOperation(
@@ -182,6 +168,25 @@ std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) {
}
}
+template <typename T>
+std::optional<Constant<T>> Folder<T>::Folding(DataRef &ref) {
+ return common::visit(
+ common::visitors{
+ [this](SymbolRef &sym) { return GetNamedConstant(*sym); },
+ [this](Component &comp) {
+ comp = FoldOperation(context_, std::move(comp));
+ return GetConstantComponent(comp);
+ },
+ [this](ArrayRef &aRef) {
+ aRef = FoldOperation(context_, std::move(aRef));
+ return Folding(aRef);
+ },
+ [](CoarrayRef &) { return std::optional<Constant<T>>{}; },
+ },
+ ref.u);
+}
+
+// TODO: This would be more natural as a member function of Constant<T>.
template <typename T>
std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array,
const std::vector<Constant<SubscriptInteger>> &subscripts) {
@@ -341,6 +346,19 @@ template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
}
}
}
+ } else if constexpr (T::category == TypeCategory::Real) {
+ if (auto *zPart{std::get_if<ComplexPart>(&designator.u)}) {
+ *zPart = FoldOperation(context_, std::move(*zPart));
+ using ComplexT = Type<TypeCategory::Complex, T::kind>;
+ if (auto zConst{Folder<ComplexT>{context_}.Folding(zPart->complex())}) {
+ return Fold(context_,
+ Expr<T>{ComplexComponent<T::kind>{
+ zPart->part() == ComplexPart::Part::IM,
+ Expr<ComplexT>{std::move(*zConst)}}});
+ } else {
+ return Expr<T>{Designator<T>{std::move(*zPart)}};
+ }
+ }
}
return common::visit(
common::visitors{
@@ -1045,6 +1063,20 @@ Expr<T> RewriteSpecificMINorMAX(
return common::visit(insertConversion, sx.u);
}
+// FoldIntrinsicFunction()
+template <int KIND>
+Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
+ FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&);
+template <int KIND>
+Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
+ FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&);
+template <int KIND>
+Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
+ FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&);
+template <int KIND>
+Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
+ FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&);
+
template <typename T>
Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
ActualArguments &args{funcRef.arguments()};
@@ -1922,6 +1954,31 @@ Expr<Type<TypeCategory::Real, KIND>> ToReal(
return result.value();
}
+// REAL(z) and AIMAG(z)
+template <int KIND>
+Expr<Type<TypeCategory::Real, KIND>> FoldOperation(
+ FoldingContext &context, ComplexComponent<KIND> &&x) {
+ using Operand = Type<TypeCategory::Complex, KIND>;
+ using Result = Type<TypeCategory::Real, KIND>;
+ if (auto array{ApplyElementwise(context, x,
+ std::function<Expr<Result>(Expr<Operand> &&)>{
+ [=](Expr<Operand> &&operand) {
+ return Expr<Result>{ComplexComponent<KIND>{
+ x.isImaginaryPart, std::move(operand)}};
+ }})}) {
+ return *array;
+ }
+ auto &operand{x.left()};
+ if (auto value{GetScalarConstantValue<Operand>(operand)}) {
+ if (x.isImaginaryPart) {
+ return Expr<Result>{Constant<Result>{value->AIMAG()}};
+ } else {
+ return Expr<Result>{Constant<Result>{value->REAL()}};
+ }
+ }
+ return Expr<Result>{std::move(x)};
+}
+
template <typename T>
Expr<T> ExpressionBase<T>::Rewrite(FoldingContext &context, Expr<T> &&expr) {
return common::visit(
@@ -1941,6 +1998,5 @@ Expr<T> ExpressionBase<T>::Rewrite(FoldingContext &context, Expr<T> &&expr) {
}
FOR_EACH_TYPE_AND_KIND(extern template class ExpressionBase, )
-
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 667a5a0e1477..0cc6b91230e7 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -112,8 +112,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
common::die(" unexpected argument type inside abs");
}
} else if (name == "aimag") {
- return FoldElementalIntrinsic<T, ComplexT>(
- context, std::move(funcRef), &Scalar<ComplexT>::AIMAG);
+ if (auto *zExpr{UnwrapExpr<Expr<ComplexT>>(args[0])}) {
+ return Fold(context, Expr<T>{ComplexComponent{true, std::move(*zExpr)}});
+ }
} else if (name == "aint" || name == "anint") {
// ANINT rounds ties away from zero, not to even
common::RoundingMode mode{name == "aint"
@@ -318,31 +319,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return Expr<T>{std::move(funcRef)};
}
-template <int KIND>
-Expr<Type<TypeCategory::Real, KIND>> FoldOperation(
- FoldingContext &context, ComplexComponent<KIND> &&x) {
- using Operand = Type<TypeCategory::Complex, KIND>;
- using Result = Type<TypeCategory::Real, KIND>;
- if (auto array{ApplyElementwise(context, x,
- std::function<Expr<Result>(Expr<Operand> &&)>{
- [=](Expr<Operand> &&operand) {
- return Expr<Result>{ComplexComponent<KIND>{
- x.isImaginaryPart, std::move(operand)}};
- }})}) {
- return *array;
- }
- using Part = Type<TypeCategory::Real, KIND>;
- auto &operand{x.left()};
- if (auto value{GetScalarConstantValue<Operand>(operand)}) {
- if (x.isImaginaryPart) {
- return Expr<Part>{Constant<Part>{value->AIMAG()}};
- } else {
- return Expr<Part>{Constant<Part>{value->REAL()}};
- }
- }
- return Expr<Part>{std::move(x)};
-}
-
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 6dc4f6a9e611..394d1027a198 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -238,6 +238,16 @@ Expr<SomeReal> GetComplexPart(const Expr<SomeComplex> &z, bool isImaginary) {
z.u);
}
+Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) {
+ return common::visit(
+ [&](auto &&zk) {
+ static constexpr int kind{ResultType<decltype(zk)>::kind};
+ return AsCategoryExpr(
+ ComplexComponent<kind>{isImaginary, std::move(zk)});
+ },
+ z.u);
+}
+
// Convert REAL to COMPLEX of the same kind. Preserving the real operand kind
// and then applying complex operand promotion rules allows the result to have
// the highest precision of REAL and COMPLEX operands as required by Fortran
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 7b3e37f358d1..da45ef4f163d 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1073,7 +1073,8 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::StructureComponent &sc) {
MiscKind kind{details->kind()};
if (kind == MiscKind::ComplexPartRe || kind == MiscKind::ComplexPartIm) {
if (auto *zExpr{std::get_if<Expr<SomeComplex>>(&base->u)}) {
- if (std::optional<DataRef> dataRef{ExtractDataRef(std::move(*zExpr))}) {
+ if (std::optional<DataRef> dataRef{ExtractDataRef(*zExpr)}) {
+ // Represent %RE/%IM as a designator
Expr<SomeReal> realExpr{common::visit(
[&](const auto &z) {
using PartType = typename ResultType<decltype(z)>::Part;
diff --git a/flang/test/Evaluate/fold-re-im.f90 b/flang/test/Evaluate/fold-re-im.f90
new file mode 100644
index 000000000000..f39989cdf973
--- /dev/null
+++ b/flang/test/Evaluate/fold-re-im.f90
@@ -0,0 +1,15 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of complex components
+module m
+ complex, parameter :: z = (1., 2.)
+ logical, parameter :: test_1 = z%re == 1.
+ logical, parameter :: test_2 = z%im == 2.
+ logical, parameter :: test_3 = real(z+z) == 2.
+ logical, parameter :: test_4 = aimag(z+z) == 4.
+ type :: t
+ complex :: z
+ end type
+ type(t), parameter :: tz(*) = [t((3., 4.)), t((5., 6.))]
+ logical, parameter :: test_5 = all(tz%z%re == [3., 5.])
+ logical, parameter :: test_6 = all(tz%z%im == [4., 6.])
+end module
More information about the flang-commits mailing list