[flang-commits] [flang] [flang] Don't blow up when combining mixed COMPLEX operations (PR #66235)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Wed Sep 13 16:21:31 PDT 2023


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/66235:

>From ca6b9e152ee117240ac1fd058db577de5c410220 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Fri, 1 Sep 2023 09:11:43 -0700
Subject: [PATCH] [flang] Don't blow up when combining mixed COMPLEX operations

Expression processing applies some straightforward rewriting
of mixed complex/real and complex/integer operations to avoid
having to promote the real/integer operand to complex and then
perform a complex operation; for example, (a,b)+x becomes
(a+x,b) rather than (a,b)+(x,0).  But the rewrite duplicates the
complex operand (once to extract each part), so nesting such
operations can blow up the expression representation when that
operand is not cheap to duplicate.  So apply this technique only
to complex operands that are simple enough to duplicate.

Fixes https://github.com/llvm/llvm-project/issues/65142.

Pull request: https://github.com/llvm/llvm-project/pull/66235
---
 flang/include/flang/Evaluate/tools.h |  10 -
 flang/lib/Evaluate/tools.cpp         | 263 ++++++++++++++++++---------
 flang/test/Evaluate/bug65142.f90     |  14 ++
 3 files changed, 189 insertions(+), 98 deletions(-)
 create mode 100644 flang/test/Evaluate/bug65142.f90

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 71fe1237efdde7c..69730286767ce95 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -149,16 +149,6 @@ common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
 
 Expr<SomeType> Parenthesize(Expr<SomeType> &&);
 
-Expr<SomeReal> GetComplexPart(
-    const Expr<SomeComplex> &, bool isImaginary = false);
-Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false);
-
-template <int KIND>
-Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
-    Expr<Type<TypeCategory::Real, KIND>> &&im) {
-  return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
-}
-
 template <typename A> constexpr bool IsNumericCategoryExpr() {
   if constexpr (common::HasMember<A, TypelessExpression>) {
     return false;
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index aadbc0804b342a7..a4afc3db06022e2 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -180,8 +180,9 @@ std::optional<Expr<SomeType>> Package(
     std::optional<Expr<SomeKind<CAT>>> &&catExpr) {
   if (catExpr) {
     return {AsGenericExpr(std::move(*catExpr))};
+  } else {
+    return std::nullopt;
   }
-  return NoExpr();
 }
 
 // Mixed REAL+INTEGER operations.  REAL**INTEGER is a special case that
@@ -204,6 +205,12 @@ std::optional<Expr<SomeType>> MixedRealLeft(
       std::move(rx.u)));
 }
 
+template <int KIND>
+Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
+    Expr<Type<TypeCategory::Real, KIND>> &&im) {
+  return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
+}
+
 std::optional<Expr<SomeComplex>> ConstructComplex(
     parser::ContextualMessages &messages, Expr<SomeType> &&real,
     Expr<SomeType> &&imaginary, int defaultRealKind) {
@@ -228,24 +235,87 @@ std::optional<Expr<SomeComplex>> ConstructComplex(
   return std::nullopt;
 }
 
-Expr<SomeReal> GetComplexPart(const Expr<SomeComplex> &z, bool isImaginary) {
-  return common::visit(
-      [&](const auto &zk) {
-        static constexpr int kind{ResultType<decltype(zk)>::kind};
-        return AsCategoryExpr(ComplexComponent<kind>{isImaginary, zk});
-      },
-      z.u);
-}
+// Extracts the real or imaginary part of the result of a COMPLEX
+// expression, when that expression is simple enough to be duplicated.
+template <bool GET_IMAGINARY> struct ComplexPartExtractor {
+  template <typename A> static std::optional<Expr<SomeReal>> Get(const A &) {
+    return std::nullopt;
+  }
 
-Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) {
-  return common::visit(
-      [&](auto &&zk) {
-        static constexpr int kind{ResultType<decltype(zk)>::kind};
-        return AsCategoryExpr(
-            ComplexComponent<kind>{isImaginary, std::move(zk)});
-      },
-      z.u);
-}
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Parentheses<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (auto x{Get(kz.left())}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Parentheses<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Negate<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (auto x{Get(kz.left())}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Negate<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Convert<Type<TypeCategory::Complex, KIND>, TypeCategory::Complex>
+          &kz) {
+    if (auto x{Get(kz.left())}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Convert<Type<TypeCategory::Real, KIND>, TypeCategory::Real>{
+              AsGenericExpr(std::move(*x))}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(const ComplexConstructor<KIND> &kz) {
+    return GET_IMAGINARY ? Get(kz.right()) : Get(kz.left());
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Constant<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (auto cz{kz.GetScalarValue()}) {
+      return AsGenericExpr(
+          AsSpecificExpr(GET_IMAGINARY ? cz->AIMAG() : cz->REAL()));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Designator<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (const auto *symbolRef{std::get_if<SymbolRef>(&kz.u)}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Designator<Type<TypeCategory::Complex, KIND>>{ComplexPart{
+              DataRef{*symbolRef},
+              GET_IMAGINARY ? ComplexPart::Part::IM : ComplexPart::Part::RE}}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Expr<Type<TypeCategory::Complex, KIND>> &kz) {
+    return Get(kz.u);
+  }
+
+  static std::optional<Expr<SomeReal>> Get(const Expr<SomeComplex> &z) {
+    return Get(z.u);
+  }
+};
 
 // Convert REAL to COMPLEX of the same kind. Preserving the real operand kind
 // and then applying complex operand promotion rules allows the result to have
@@ -266,19 +336,31 @@ Expr<SomeComplex> PromoteRealToComplex(Expr<SomeReal> &&someX) {
 // corresponding COMPLEX+COMPLEX operation.
 template <template <typename> class OPR, TypeCategory RCAT>
 std::optional<Expr<SomeType>> MixedComplexLeft(
-    parser::ContextualMessages &messages, Expr<SomeComplex> &&zx,
-    Expr<SomeKind<RCAT>> &&iry, [[maybe_unused]] int defaultRealKind) {
-  Expr<SomeReal> zr{GetComplexPart(zx, false)};
-  Expr<SomeReal> zi{GetComplexPart(zx, true)};
-  if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
+    parser::ContextualMessages &messages, const Expr<SomeComplex> &zx,
+    const Expr<SomeKind<RCAT>> &iry, [[maybe_unused]] int defaultRealKind) {
+  if constexpr (RCAT == TypeCategory::Integer &&
+      std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
+    // COMPLEX**INTEGER is a special case that doesn't convert the exponent.
+    return Package(common::visit(
+        [&](const auto &zxk) {
+          using Ty = ResultType<decltype(zxk)>;
+          return AsCategoryExpr(AsExpr(
+              RealToIntPower<Ty>{common::Clone(zxk), common::Clone(iry)}));
+        },
+        zx.u));
+  }
+  std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zx)};
+  std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zx)};
+  if (!zr || !zi) {
+  } else if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
       std::is_same_v<OPR<LargestReal>, Subtract<LargestReal>>) {
     // (a,b) + x -> (a+x, b)
     // (a,b) - x -> (a-x, b)
     if (std::optional<Expr<SomeType>> rr{
-            NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
-                AsGenericExpr(std::move(iry)), defaultRealKind)}) {
+            NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
+                AsGenericExpr(common::Clone(iry)), defaultRealKind)}) {
       return Package(ConstructComplex(messages, std::move(*rr),
-          AsGenericExpr(std::move(zi)), defaultRealKind));
+          AsGenericExpr(std::move(*zi)), defaultRealKind));
     }
   } else if constexpr (allowOperandDuplication &&
       (std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>> ||
@@ -286,36 +368,16 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
     // (a,b) * x -> (a*x, b*x)
     // (a,b) / x -> (a/x, b/x)
     auto copy{iry};
-    auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
-        AsGenericExpr(std::move(iry)), defaultRealKind)};
-    auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zi)),
+    auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
+        AsGenericExpr(common::Clone(iry)), defaultRealKind)};
+    auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zi)),
         AsGenericExpr(std::move(copy)), defaultRealKind)};
     if (auto parts{common::AllPresent(std::move(rr), std::move(ri))}) {
       return Package(ConstructComplex(messages, std::get<0>(std::move(*parts)),
           std::get<1>(std::move(*parts)), defaultRealKind));
     }
-  } else if constexpr (RCAT == TypeCategory::Integer &&
-      std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
-    // COMPLEX**INTEGER is a special case that doesn't convert the exponent.
-    static_assert(RCAT == TypeCategory::Integer);
-    return Package(common::visit(
-        [&](auto &&zxk) {
-          using Ty = ResultType<decltype(zxk)>;
-          return AsCategoryExpr(
-              AsExpr(RealToIntPower<Ty>{std::move(zxk), std::move(iry)}));
-        },
-        std::move(zx.u)));
-  } else {
-    // (a,b) ** x -> (a,b) ** (x,0)
-    if constexpr (RCAT == TypeCategory::Integer) {
-      Expr<SomeComplex> zy{ConvertTo(zx, std::move(iry))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
-    } else {
-      Expr<SomeComplex> zy{PromoteRealToComplex(std::move(iry))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
-    }
   }
-  return NoExpr();
+  return std::nullopt;
 }
 
 // Mixed COMPLEX operations with the COMPLEX operand on the right.
@@ -325,39 +387,49 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
 //  x / (a,b) -> (x,0) / (a,b)   (and **)
 template <template <typename> class OPR, TypeCategory LCAT>
 std::optional<Expr<SomeType>> MixedComplexRight(
-    parser::ContextualMessages &messages, Expr<SomeKind<LCAT>> &&irx,
-    Expr<SomeComplex> &&zy, [[maybe_unused]] int defaultRealKind) {
+    parser::ContextualMessages &messages, const Expr<SomeKind<LCAT>> &irx,
+    const Expr<SomeComplex> &zy, [[maybe_unused]] int defaultRealKind) {
   if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>>) {
     // x + (a,b) -> (a,b) + x -> (a+x, b)
-    return MixedComplexLeft<OPR, LCAT>(
-        messages, std::move(zy), std::move(irx), defaultRealKind);
+    return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
   } else if constexpr (allowOperandDuplication &&
       std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>>) {
     // x * (a,b) -> (a,b) * x -> (a*x, b*x)
-    return MixedComplexLeft<OPR, LCAT>(
-        messages, std::move(zy), std::move(irx), defaultRealKind);
+    return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
   } else if constexpr (std::is_same_v<OPR<LargestReal>,
                            Subtract<LargestReal>>) {
     // x - (a,b) -> (x-a, -b)
-    Expr<SomeReal> zr{GetComplexPart(zy, false)};
-    Expr<SomeReal> zi{GetComplexPart(zy, true)};
-    if (std::optional<Expr<SomeType>> rr{
-            NumericOperation<Subtract>(messages, AsGenericExpr(std::move(irx)),
-                AsGenericExpr(std::move(zr)), defaultRealKind)}) {
-      return Package(ConstructComplex(messages, std::move(*rr),
-          AsGenericExpr(-std::move(zi)), defaultRealKind));
-    }
-  } else {
-    // x / (a,b) -> (x,0) / (a,b)
-    if constexpr (LCAT == TypeCategory::Integer) {
-      Expr<SomeComplex> zx{ConvertTo(zy, std::move(irx))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
-    } else {
-      Expr<SomeComplex> zx{PromoteRealToComplex(std::move(irx))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
+    std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zy)};
+    std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zy)};
+    if (zr && zi) {
+      if (std::optional<Expr<SomeType>> rr{NumericOperation<Subtract>(messages,
+              AsGenericExpr(common::Clone(irx)), AsGenericExpr(std::move(*zr)),
+              defaultRealKind)}) {
+        return Package(ConstructComplex(messages, std::move(*rr),
+            AsGenericExpr(-std::move(*zi)), defaultRealKind));
+      }
     }
   }
-  return NoExpr();
+  return std::nullopt;
+}
+
+// Promotes REAL(rk) and COMPLEX(zk) operands to COMPLEX(max(rk,zk))
+// and then combines them with an operator.
+template <template <typename> class OPR, TypeCategory XCAT, TypeCategory YCAT>
+Expr<SomeComplex> PromoteMixedComplexReal(
+    Expr<SomeKind<XCAT>> &&x, Expr<SomeKind<YCAT>> &&y) {
+  static_assert(XCAT == TypeCategory::Complex || YCAT == TypeCategory::Complex);
+  static_assert(XCAT == TypeCategory::Real || YCAT == TypeCategory::Real);
+  return common::visit(
+      [&](const auto &kx, const auto &ky) {
+        constexpr int maxKind{std::max(
+            ResultType<decltype(kx)>::kind, ResultType<decltype(ky)>::kind)};
+        using ZTy = Type<TypeCategory::Complex, maxKind>;
+        return Expr<SomeComplex>{
+            Expr<ZTy>{OPR<ZTy>{ConvertToType<ZTy>(std::move(x)),
+                ConvertToType<ZTy>(std::move(y))}}};
+      },
+      x.u, y.u);
 }
 
 // N.B. When a "typeless" BOZ literal constant appears as one (not both!) of
@@ -397,20 +469,40 @@ std::optional<Expr<SomeType>> NumericOperation(
                 std::move(zx), std::move(zy)));
           },
           [&](Expr<SomeComplex> &&zx, Expr<SomeInteger> &&iy) {
-            return MixedComplexLeft<OPR>(
-                messages, std::move(zx), std::move(iy), defaultRealKind);
+            if (auto result{
+                    MixedComplexLeft<OPR>(messages, zx, iy, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
+                  std::move(zx), ConvertTo(zx, std::move(iy))));
+            }
           },
           [&](Expr<SomeComplex> &&zx, Expr<SomeReal> &&ry) {
-            return MixedComplexLeft<OPR>(
-                messages, std::move(zx), std::move(ry), defaultRealKind);
+            if (auto result{
+                    MixedComplexLeft<OPR>(messages, zx, ry, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(
+                  PromoteMixedComplexReal<OPR>(std::move(zx), std::move(ry)));
+            }
           },
           [&](Expr<SomeInteger> &&ix, Expr<SomeComplex> &&zy) {
-            return MixedComplexRight<OPR>(
-                messages, std::move(ix), std::move(zy), defaultRealKind);
+            if (auto result{MixedComplexRight<OPR>(
+                    messages, ix, zy, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
+                  ConvertTo(zy, std::move(ix)), std::move(zy)));
+            }
           },
           [&](Expr<SomeReal> &&rx, Expr<SomeComplex> &&zy) {
-            return MixedComplexRight<OPR>(
-                messages, std::move(rx), std::move(zy), defaultRealKind);
+            if (auto result{MixedComplexRight<OPR>(
+                    messages, rx, zy, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(
+                  PromoteMixedComplexReal<OPR>(std::move(rx), std::move(zy)));
+            }
           },
           // Operations with one typeless operand
           [&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) {
@@ -433,7 +525,6 @@ std::optional<Expr<SomeType>> NumericOperation(
           },
           // Default case
           [&](auto &&, auto &&) {
-            // TODO: defined operator
             messages.Say("non-numeric operands to numeric operation"_err_en_US);
             return NoExpr();
           },
@@ -481,17 +572,14 @@ std::optional<Expr<SomeType>> Negation(
           [&](Expr<SomeReal> &&x) { return Package(-std::move(x)); },
           [&](Expr<SomeComplex> &&x) { return Package(-std::move(x)); },
           [&](Expr<SomeCharacter> &&) {
-            // TODO: defined operator
             messages.Say("CHARACTER cannot be negated"_err_en_US);
             return NoExpr();
           },
           [&](Expr<SomeLogical> &&) {
-            // TODO: defined operator
             messages.Say("LOGICAL cannot be negated"_err_en_US);
             return NoExpr();
           },
           [&](Expr<SomeDerived> &&) {
-            // TODO: defined operator
             messages.Say("Operand cannot be negated"_err_en_US);
             return NoExpr();
           },
@@ -643,8 +731,7 @@ std::optional<Expr<SomeType>> ConvertToType(
       if (auto length{type.GetCharLength()}) {
         converted = common::visit(
             [&](auto &&x) {
-              using Ty = std::decay_t<decltype(x)>;
-              using CharacterType = typename Ty::Result;
+              using CharacterType = ResultType<decltype(x)>;
               return Expr<SomeCharacter>{
                   Expr<CharacterType>{SetLength<CharacterType::kind>{
                       std::move(x), std::move(*length)}}};
@@ -1099,7 +1186,7 @@ static std::optional<Expr<SomeType>> DataConstantConversionHelper(
     if (const auto *someExpr{UnwrapExpr<Expr<SomeKind<FROM>>>(*sized)}) {
       return common::visit(
           [](const auto &w) -> std::optional<Expr<SomeType>> {
-            using FromType = typename std::decay_t<decltype(w)>::Result;
+            using FromType = ResultType<decltype(w)>;
             static constexpr int kind{FromType::kind};
             if constexpr (IsValidKindOfIntrinsicType(TO, kind)) {
               if (const auto *fromConst{UnwrapExpr<Constant<FromType>>(w)}) {
diff --git a/flang/test/Evaluate/bug65142.f90 b/flang/test/Evaluate/bug65142.f90
new file mode 100644
index 000000000000000..e9bac4f5bbe0cb5
--- /dev/null
+++ b/flang/test/Evaluate/bug65142.f90
@@ -0,0 +1,14 @@
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+! Ensure that expression rewriting doesn't blow up when many
+! mixed complex operations are combined.
+! The result of folding (8,1.542956)**9 is checked approximately
+! to allow for differing implementations of cpowi.
+PROGRAM ProgramName0
+ COMPLEX complexVar0
+ REAL realVar0
+ INTEGER intVar0
+ complexVar0 = (8,1.542956)**9/intVar0/realVar0+6+intVar0+intVar0**5-4&
+ &+intVar0**intVar0/8**intVar0/intVar&
+ &0/intVar0+0/3-3+9/9/5+7+5**0*0**10-0/2**1**2-4+intVar0*intVar0-3**intVar0-6
+!CHECK: complexvar0=(-2.{{[0-9]*}}e7_4,1.{{[0-9]*}}e8_4)/(real(intvar0,kind=4),0._4)/(realvar0,0._4)+(6._4,0._4)+(real(intvar0,kind=4),0._4)+(real(intvar0**5_4,kind=4),0._4)-(4._4,0._4)+(real(intvar0**intvar0/8_4**intvar0/intvar0/intvar0,kind=4),0._4)+(0._4,0._4)-(3._4,0._4)+(0._4,0._4)+(7._4,0._4)+(0._4,0._4)-(0._4,0._4)-(4._4,0._4)+(real(intvar0*intvar0,kind=4),0._4)-(real(3_4**intvar0,kind=4),0._4)-(6._4,0._4)
+END PROGRAM



More information about the flang-commits mailing list