[flang-commits] [flang] daa5da0 - [flang] Don't blow up when combining mixed COMPLEX operations (#66235)

via flang-commits flang-commits at lists.llvm.org
Wed Sep 13 16:34:27 PDT 2023


Author: Peter Klausler
Date: 2023-09-13T16:34:23-07:00
New Revision: daa5da063ae8b39efa7368475f33db3313b41e30

URL: https://github.com/llvm/llvm-project/commit/daa5da063ae8b39efa7368475f33db3313b41e30
DIFF: https://github.com/llvm/llvm-project/commit/daa5da063ae8b39efa7368475f33db3313b41e30.diff

LOG: [flang] Don't blow up when combining mixed COMPLEX operations (#66235)

Expression processing applies some straightforward rewriting of mixed
complex/real and complex/integer operations to avoid having to promote
the real/integer operand to complex and then perform a complex
operation; for example, (a,b)+x becomes (a+x,b) rather than (a,b)+(x,0).
But this can blow up the expression representation when the complex
operand cannot be duplicated cheaply. So apply this technique only to
complex operands that are appropriate to duplicate.

Fixes https://github.com/llvm/llvm-project/issues/65142.

Added: 
    flang/test/Evaluate/bug65142.f90

Modified: 
    flang/include/flang/Evaluate/tools.h
    flang/lib/Evaluate/tools.cpp

Removed: 
    


################################################################################
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 71fe1237efdde7c..69730286767ce95 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -149,16 +149,6 @@ common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
 
 Expr<SomeType> Parenthesize(Expr<SomeType> &&);
 
-Expr<SomeReal> GetComplexPart(
-    const Expr<SomeComplex> &, bool isImaginary = false);
-Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false);
-
-template <int KIND>
-Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
-    Expr<Type<TypeCategory::Real, KIND>> &&im) {
-  return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
-}
-
 template <typename A> constexpr bool IsNumericCategoryExpr() {
   if constexpr (common::HasMember<A, TypelessExpression>) {
     return false;

diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index aadbc0804b342a7..a4afc3db06022e2 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -180,8 +180,9 @@ std::optional<Expr<SomeType>> Package(
     std::optional<Expr<SomeKind<CAT>>> &&catExpr) {
   if (catExpr) {
     return {AsGenericExpr(std::move(*catExpr))};
+  } else {
+    return std::nullopt;
   }
-  return NoExpr();
 }
 
 // Mixed REAL+INTEGER operations.  REAL**INTEGER is a special case that
@@ -204,6 +205,12 @@ std::optional<Expr<SomeType>> MixedRealLeft(
       std::move(rx.u)));
 }
 
+template <int KIND>
+Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
+    Expr<Type<TypeCategory::Real, KIND>> &&im) {
+  return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
+}
+
 std::optional<Expr<SomeComplex>> ConstructComplex(
     parser::ContextualMessages &messages, Expr<SomeType> &&real,
     Expr<SomeType> &&imaginary, int defaultRealKind) {
@@ -228,24 +235,87 @@ std::optional<Expr<SomeComplex>> ConstructComplex(
   return std::nullopt;
 }
 
-Expr<SomeReal> GetComplexPart(const Expr<SomeComplex> &z, bool isImaginary) {
-  return common::visit(
-      [&](const auto &zk) {
-        static constexpr int kind{ResultType<decltype(zk)>::kind};
-        return AsCategoryExpr(ComplexComponent<kind>{isImaginary, zk});
-      },
-      z.u);
-}
+// Extracts the real or imaginary part of the result of a COMPLEX
+// expression, when that expression is simple enough to be duplicated.
+template <bool GET_IMAGINARY> struct ComplexPartExtractor {
+  template <typename A> static std::optional<Expr<SomeReal>> Get(const A &) {
+    return std::nullopt;
+  }
 
-Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) {
-  return common::visit(
-      [&](auto &&zk) {
-        static constexpr int kind{ResultType<decltype(zk)>::kind};
-        return AsCategoryExpr(
-            ComplexComponent<kind>{isImaginary, std::move(zk)});
-      },
-      z.u);
-}
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Parentheses<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (auto x{Get(kz.left())}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Parentheses<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Negate<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (auto x{Get(kz.left())}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Negate<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Convert<Type<TypeCategory::Complex, KIND>, TypeCategory::Complex>
+          &kz) {
+    if (auto x{Get(kz.left())}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Convert<Type<TypeCategory::Real, KIND>, TypeCategory::Real>{
+              AsGenericExpr(std::move(*x))}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(const ComplexConstructor<KIND> &kz) {
+    return GET_IMAGINARY ? Get(kz.right()) : Get(kz.left());
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Constant<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (auto cz{kz.GetScalarValue()}) {
+      return AsGenericExpr(
+          AsSpecificExpr(GET_IMAGINARY ? cz->AIMAG() : cz->REAL()));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Designator<Type<TypeCategory::Complex, KIND>> &kz) {
+    if (const auto *symbolRef{std::get_if<SymbolRef>(&kz.u)}) {
+      return AsGenericExpr(AsSpecificExpr(
+          Designator<Type<TypeCategory::Complex, KIND>>{ComplexPart{
+              DataRef{*symbolRef},
+              GET_IMAGINARY ? ComplexPart::Part::IM : ComplexPart::Part::RE}}));
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  template <int KIND>
+  static std::optional<Expr<SomeReal>> Get(
+      const Expr<Type<TypeCategory::Complex, KIND>> &kz) {
+    return Get(kz.u);
+  }
+
+  static std::optional<Expr<SomeReal>> Get(const Expr<SomeComplex> &z) {
+    return Get(z.u);
+  }
+};
 
 // Convert REAL to COMPLEX of the same kind. Preserving the real operand kind
 // and then applying complex operand promotion rules allows the result to have
@@ -266,19 +336,31 @@ Expr<SomeComplex> PromoteRealToComplex(Expr<SomeReal> &&someX) {
 // corresponding COMPLEX+COMPLEX operation.
 template <template <typename> class OPR, TypeCategory RCAT>
 std::optional<Expr<SomeType>> MixedComplexLeft(
-    parser::ContextualMessages &messages, Expr<SomeComplex> &&zx,
-    Expr<SomeKind<RCAT>> &&iry, [[maybe_unused]] int defaultRealKind) {
-  Expr<SomeReal> zr{GetComplexPart(zx, false)};
-  Expr<SomeReal> zi{GetComplexPart(zx, true)};
-  if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
+    parser::ContextualMessages &messages, const Expr<SomeComplex> &zx,
+    const Expr<SomeKind<RCAT>> &iry, [[maybe_unused]] int defaultRealKind) {
+  if constexpr (RCAT == TypeCategory::Integer &&
+      std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
+    // COMPLEX**INTEGER is a special case that doesn't convert the exponent.
+    return Package(common::visit(
+        [&](const auto &zxk) {
+          using Ty = ResultType<decltype(zxk)>;
+          return AsCategoryExpr(AsExpr(
+              RealToIntPower<Ty>{common::Clone(zxk), common::Clone(iry)}));
+        },
+        zx.u));
+  }
+  std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zx)};
+  std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zx)};
+  if (!zr || !zi) {
+  } else if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
       std::is_same_v<OPR<LargestReal>, Subtract<LargestReal>>) {
     // (a,b) + x -> (a+x, b)
     // (a,b) - x -> (a-x, b)
     if (std::optional<Expr<SomeType>> rr{
-            NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
-                AsGenericExpr(std::move(iry)), defaultRealKind)}) {
+            NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
+                AsGenericExpr(common::Clone(iry)), defaultRealKind)}) {
       return Package(ConstructComplex(messages, std::move(*rr),
-          AsGenericExpr(std::move(zi)), defaultRealKind));
+          AsGenericExpr(std::move(*zi)), defaultRealKind));
     }
   } else if constexpr (allowOperandDuplication &&
       (std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>> ||
@@ -286,36 +368,16 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
     // (a,b) * x -> (a*x, b*x)
     // (a,b) / x -> (a/x, b/x)
     auto copy{iry};
-    auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
-        AsGenericExpr(std::move(iry)), defaultRealKind)};
-    auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zi)),
+    auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
+        AsGenericExpr(common::Clone(iry)), defaultRealKind)};
+    auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zi)),
         AsGenericExpr(std::move(copy)), defaultRealKind)};
     if (auto parts{common::AllPresent(std::move(rr), std::move(ri))}) {
       return Package(ConstructComplex(messages, std::get<0>(std::move(*parts)),
           std::get<1>(std::move(*parts)), defaultRealKind));
     }
-  } else if constexpr (RCAT == TypeCategory::Integer &&
-      std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
-    // COMPLEX**INTEGER is a special case that doesn't convert the exponent.
-    static_assert(RCAT == TypeCategory::Integer);
-    return Package(common::visit(
-        [&](auto &&zxk) {
-          using Ty = ResultType<decltype(zxk)>;
-          return AsCategoryExpr(
-              AsExpr(RealToIntPower<Ty>{std::move(zxk), std::move(iry)}));
-        },
-        std::move(zx.u)));
-  } else {
-    // (a,b) ** x -> (a,b) ** (x,0)
-    if constexpr (RCAT == TypeCategory::Integer) {
-      Expr<SomeComplex> zy{ConvertTo(zx, std::move(iry))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
-    } else {
-      Expr<SomeComplex> zy{PromoteRealToComplex(std::move(iry))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
-    }
   }
-  return NoExpr();
+  return std::nullopt;
 }
 
 // Mixed COMPLEX operations with the COMPLEX operand on the right.
@@ -325,39 +387,49 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
 //  x / (a,b) -> (x,0) / (a,b)   (and **)
 template <template <typename> class OPR, TypeCategory LCAT>
 std::optional<Expr<SomeType>> MixedComplexRight(
-    parser::ContextualMessages &messages, Expr<SomeKind<LCAT>> &&irx,
-    Expr<SomeComplex> &&zy, [[maybe_unused]] int defaultRealKind) {
+    parser::ContextualMessages &messages, const Expr<SomeKind<LCAT>> &irx,
+    const Expr<SomeComplex> &zy, [[maybe_unused]] int defaultRealKind) {
   if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>>) {
     // x + (a,b) -> (a,b) + x -> (a+x, b)
-    return MixedComplexLeft<OPR, LCAT>(
-        messages, std::move(zy), std::move(irx), defaultRealKind);
+    return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
   } else if constexpr (allowOperandDuplication &&
       std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>>) {
     // x * (a,b) -> (a,b) * x -> (a*x, b*x)
-    return MixedComplexLeft<OPR, LCAT>(
-        messages, std::move(zy), std::move(irx), defaultRealKind);
+    return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
   } else if constexpr (std::is_same_v<OPR<LargestReal>,
                            Subtract<LargestReal>>) {
     // x - (a,b) -> (x-a, -b)
-    Expr<SomeReal> zr{GetComplexPart(zy, false)};
-    Expr<SomeReal> zi{GetComplexPart(zy, true)};
-    if (std::optional<Expr<SomeType>> rr{
-            NumericOperation<Subtract>(messages, AsGenericExpr(std::move(irx)),
-                AsGenericExpr(std::move(zr)), defaultRealKind)}) {
-      return Package(ConstructComplex(messages, std::move(*rr),
-          AsGenericExpr(-std::move(zi)), defaultRealKind));
-    }
-  } else {
-    // x / (a,b) -> (x,0) / (a,b)
-    if constexpr (LCAT == TypeCategory::Integer) {
-      Expr<SomeComplex> zx{ConvertTo(zy, std::move(irx))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
-    } else {
-      Expr<SomeComplex> zx{PromoteRealToComplex(std::move(irx))};
-      return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
+    std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zy)};
+    std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zy)};
+    if (zr && zi) {
+      if (std::optional<Expr<SomeType>> rr{NumericOperation<Subtract>(messages,
+              AsGenericExpr(common::Clone(irx)), AsGenericExpr(std::move(*zr)),
+              defaultRealKind)}) {
+        return Package(ConstructComplex(messages, std::move(*rr),
+            AsGenericExpr(-std::move(*zi)), defaultRealKind));
+      }
     }
   }
-  return NoExpr();
+  return std::nullopt;
+}
+
+// Promotes REAL(rk) and COMPLEX(zk) operands COMPLEX(max(rk,zk))
+// then combine them with an operator.
+template <template <typename> class OPR, TypeCategory XCAT, TypeCategory YCAT>
+Expr<SomeComplex> PromoteMixedComplexReal(
+    Expr<SomeKind<XCAT>> &&x, Expr<SomeKind<YCAT>> &&y) {
+  static_assert(XCAT == TypeCategory::Complex || YCAT == TypeCategory::Complex);
+  static_assert(XCAT == TypeCategory::Real || YCAT == TypeCategory::Real);
+  return common::visit(
+      [&](const auto &kx, const auto &ky) {
+        constexpr int maxKind{std::max(
+            ResultType<decltype(kx)>::kind, ResultType<decltype(ky)>::kind)};
+        using ZTy = Type<TypeCategory::Complex, maxKind>;
+        return Expr<SomeComplex>{
+            Expr<ZTy>{OPR<ZTy>{ConvertToType<ZTy>(std::move(x)),
+                ConvertToType<ZTy>(std::move(y))}}};
+      },
+      x.u, y.u);
 }
 
 // N.B. When a "typeless" BOZ literal constant appears as one (not both!) of
@@ -397,20 +469,40 @@ std::optional<Expr<SomeType>> NumericOperation(
                 std::move(zx), std::move(zy)));
           },
           [&](Expr<SomeComplex> &&zx, Expr<SomeInteger> &&iy) {
-            return MixedComplexLeft<OPR>(
-                messages, std::move(zx), std::move(iy), defaultRealKind);
+            if (auto result{
+                    MixedComplexLeft<OPR>(messages, zx, iy, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
+                  std::move(zx), ConvertTo(zx, std::move(iy))));
+            }
           },
           [&](Expr<SomeComplex> &&zx, Expr<SomeReal> &&ry) {
-            return MixedComplexLeft<OPR>(
-                messages, std::move(zx), std::move(ry), defaultRealKind);
+            if (auto result{
+                    MixedComplexLeft<OPR>(messages, zx, ry, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(
+                  PromoteMixedComplexReal<OPR>(std::move(zx), std::move(ry)));
+            }
           },
           [&](Expr<SomeInteger> &&ix, Expr<SomeComplex> &&zy) {
-            return MixedComplexRight<OPR>(
-                messages, std::move(ix), std::move(zy), defaultRealKind);
+            if (auto result{MixedComplexRight<OPR>(
+                    messages, ix, zy, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
+                  ConvertTo(zy, std::move(ix)), std::move(zy)));
+            }
           },
           [&](Expr<SomeReal> &&rx, Expr<SomeComplex> &&zy) {
-            return MixedComplexRight<OPR>(
-                messages, std::move(rx), std::move(zy), defaultRealKind);
+            if (auto result{MixedComplexRight<OPR>(
+                    messages, rx, zy, defaultRealKind)}) {
+              return result;
+            } else {
+              return Package(
+                  PromoteMixedComplexReal<OPR>(std::move(rx), std::move(zy)));
+            }
           },
           // Operations with one typeless operand
           [&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) {
@@ -433,7 +525,6 @@ std::optional<Expr<SomeType>> NumericOperation(
           },
           // Default case
           [&](auto &&, auto &&) {
-            // TODO: defined operator
             messages.Say("non-numeric operands to numeric operation"_err_en_US);
             return NoExpr();
           },
@@ -481,17 +572,14 @@ std::optional<Expr<SomeType>> Negation(
           [&](Expr<SomeReal> &&x) { return Package(-std::move(x)); },
           [&](Expr<SomeComplex> &&x) { return Package(-std::move(x)); },
           [&](Expr<SomeCharacter> &&) {
-            // TODO: defined operator
             messages.Say("CHARACTER cannot be negated"_err_en_US);
             return NoExpr();
           },
           [&](Expr<SomeLogical> &&) {
-            // TODO: defined operator
             messages.Say("LOGICAL cannot be negated"_err_en_US);
             return NoExpr();
           },
           [&](Expr<SomeDerived> &&) {
-            // TODO: defined operator
             messages.Say("Operand cannot be negated"_err_en_US);
             return NoExpr();
           },
@@ -643,8 +731,7 @@ std::optional<Expr<SomeType>> ConvertToType(
       if (auto length{type.GetCharLength()}) {
         converted = common::visit(
             [&](auto &&x) {
-              using Ty = std::decay_t<decltype(x)>;
-              using CharacterType = typename Ty::Result;
+              using CharacterType = ResultType<decltype(x)>;
               return Expr<SomeCharacter>{
                   Expr<CharacterType>{SetLength<CharacterType::kind>{
                       std::move(x), std::move(*length)}}};
@@ -1099,7 +1186,7 @@ static std::optional<Expr<SomeType>> DataConstantConversionHelper(
     if (const auto *someExpr{UnwrapExpr<Expr<SomeKind<FROM>>>(*sized)}) {
       return common::visit(
           [](const auto &w) -> std::optional<Expr<SomeType>> {
-            using FromType = typename std::decay_t<decltype(w)>::Result;
+            using FromType = ResultType<decltype(w)>;
             static constexpr int kind{FromType::kind};
             if constexpr (IsValidKindOfIntrinsicType(TO, kind)) {
               if (const auto *fromConst{UnwrapExpr<Constant<FromType>>(w)}) {

diff --git a/flang/test/Evaluate/bug65142.f90 b/flang/test/Evaluate/bug65142.f90
new file mode 100644
index 000000000000000..e9bac4f5bbe0cb5
--- /dev/null
+++ b/flang/test/Evaluate/bug65142.f90
@@ -0,0 +1,14 @@
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+! Ensure that expression rewriting doesn't blow up when many
+! mixed complex operations are combined.
+! The result of folding (8,1.542956)**9 is checked approximately
+! to allow for differing implementations of cpowi.
+PROGRAM ProgramName0
+ COMPLEX complexVar0
+ REAL realVar0
+ INTEGER intVar0
+ complexVar0 = (8,1.542956)**9/intVar0/realVar0+6+intVar0+intVar0**5-4&
+ &+intVar0**intVar0/8**intVar0/intVar&
+ &0/intVar0+0/3-3+9/9/5+7+5**0*0**10-0/2**1**2-4+intVar0*intVar0-3**intVar0-6
+!CHECK: complexvar0=(-2.{{[0-9]*}}e7_4,1.{{[0-9]*}}e8_4)/(real(intvar0,kind=4),0._4)/(realvar0,0._4)+(6._4,0._4)+(real(intvar0,kind=4),0._4)+(real(intvar0**5_4,kind=4),0._4)-(4._4,0._4)+(real(intvar0**intvar0/8_4**intvar0/intvar0/intvar0,kind=4),0._4)+(0._4,0._4)-(3._4,0._4)+(0._4,0._4)+(7._4,0._4)+(0._4,0._4)-(0._4,0._4)-(4._4,0._4)+(real(intvar0*intvar0,kind=4),0._4)-(real(3_4**intvar0,kind=4),0._4)-(6._4,0._4)
+END PROGRAM


        


More information about the flang-commits mailing list