[flang-commits] [flang] b617df6 - [flang] Avoid crashing from recursion on very tall expression parse trees

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Wed Feb 1 14:42:49 PST 2023


Author: Peter Klausler
Date: 2023-02-01T14:09:07-08:00
New Revision: b617df6e7cbc35230a4b7140cf66dabebe9700e7

URL: https://github.com/llvm/llvm-project/commit/b617df6e7cbc35230a4b7140cf66dabebe9700e7
DIFF: https://github.com/llvm/llvm-project/commit/b617df6e7cbc35230a4b7140cf66dabebe9700e7.diff

LOG: [flang] Avoid crashing from recursion on very tall expression parse trees

In the parse tree visitation framework (Parser/parse-tree-visitor.h)
and in the semantic analyzer for expressions (Semantics/expression.cpp),
avoid crashing due to stack size limitations by using an iterative
traversal algorithm rather than straightforward recursive tree walking.
The iterative approach is the obvious one of building a work queue and
using it to (in the case of the parse tree visitor) call the visitor
object's Pre() and Post() routines on subexpressions in the same order
as they would have been called during a recursive traversal.

This change helps the compiler survive some artificial stress tests,
and perhaps future exposure to machine-generated source code.

Differential Revision: https://reviews.llvm.org/D142771

Added: 
    flang/test/Evaluate/big-expr-tree.F90

Modified: 
    flang/include/flang/Parser/parse-tree-visitor.h
    flang/include/flang/Semantics/expression.h
    flang/lib/Semantics/expression.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Parser/parse-tree-visitor.h b/flang/include/flang/Parser/parse-tree-visitor.h
index 4e749d3c8b4e1..75466e6c621e3 100644
--- a/flang/include/flang/Parser/parse-tree-visitor.h
+++ b/flang/include/flang/Parser/parse-tree-visitor.h
@@ -16,6 +16,7 @@
 #include <tuple>
 #include <utility>
 #include <variant>
+#include <vector>
 
 /// Parse tree visitor
 /// Call Walk(x, visitor) to visit x and, by default, each node under x.
@@ -483,20 +484,76 @@ template <typename M> void Walk(CommonStmt &x, M &mutator) {
     mutator.Post(x);
   }
 }
+
+// Expr traversal uses iteration rather than recursion to avoid
+// blowing out the stack on very deep expression parse trees.
+// It replaces implementations that looked like:
+//   template <typename V> void Walk(const Expr &x, V visitor) {
+//     if (visitor.Pre(x)) {      // Pre on the Expr
+//       Walk(x.source, visitor);
+//       // Pre on the operator, walk the operands, Post on operator
+//       Walk(x.u, visitor);
+//       visitor.Post(x);         // Post on the Expr
+//     }
+//   }
+template <typename A, typename V, typename UNARY, typename BINARY>
+static void IterativeWalk(A &start, V &visitor) {
+  struct ExprWorkList {
+    ExprWorkList(A &x) : expr(&x) {}
+    bool doPostExpr{false}, doPostOpr{false};
+    A *expr;
+  };
+  std::vector<ExprWorkList> stack;
+  stack.emplace_back(start);
+  do {
+    A &expr{*stack.back().expr};
+    if (stack.back().doPostOpr) {
+      stack.back().doPostOpr = false;
+      common::visit([&visitor](auto &y) { visitor.Post(y); }, expr.u);
+    } else if (stack.back().doPostExpr) {
+      visitor.Post(expr);
+      stack.pop_back();
+    } else if (!visitor.Pre(expr)) {
+      stack.pop_back();
+    } else {
+      stack.back().doPostExpr = true;
+      Walk(expr.source, visitor);
+      UNARY *unary{nullptr};
+      BINARY *binary{nullptr};
+      common::visit(
+          [&unary, &binary](auto &y) {
+            if constexpr (std::is_convertible_v<decltype(&y), UNARY *>) {
+              unary = &y;
+            } else if constexpr (std::is_convertible_v<decltype(&y),
+                                     BINARY *>) {
+              binary = &y;
+            }
+          },
+          expr.u);
+      if (!unary && !binary) {
+        Walk(expr.u, visitor);
+      } else if (common::visit(
+                     [&visitor](auto &y) { return visitor.Pre(y); }, expr.u)) {
+        stack.back().doPostOpr = true;
+        if (unary) {
+          stack.emplace_back(unary->v.value());
+        } else {
+          stack.emplace_back(std::get<1>(binary->t).value());
+          stack.emplace_back(std::get<0>(binary->t).value());
+        }
+      }
+    }
+  } while (!stack.empty());
+}
 template <typename V> void Walk(const Expr &x, V &visitor) {
-  if (visitor.Pre(x)) {
-    Walk(x.source, visitor);
-    Walk(x.u, visitor);
-    visitor.Post(x);
-  }
+  IterativeWalk<const Expr, V, const Expr::IntrinsicUnary,
+      const Expr::IntrinsicBinary>(x, visitor);
 }
 template <typename M> void Walk(Expr &x, M &mutator) {
-  if (mutator.Pre(x)) {
-    Walk(x.source, mutator);
-    Walk(x.u, mutator);
-    mutator.Post(x);
-  }
+  IterativeWalk<Expr, M, Expr::IntrinsicUnary, Expr::IntrinsicBinary>(
+      x, mutator);
 }
+
 template <typename V> void Walk(const Designator &x, V &visitor) {
   if (visitor.Pre(x)) {
     Walk(x.source, visitor);

diff  --git a/flang/include/flang/Semantics/expression.h b/flang/include/flang/Semantics/expression.h
index 1e56dded3547d..e8c313b9b9f38 100644
--- a/flang/include/flang/Semantics/expression.h
+++ b/flang/include/flang/Semantics/expression.h
@@ -381,6 +381,8 @@ class ExpressionAnalyzer {
   bool CheckIsValidForwardReference(const semantics::DerivedTypeSpec &);
   MaybeExpr AnalyzeComplex(MaybeExpr &&re, MaybeExpr &&im, const char *what);
 
+  MaybeExpr IterativelyAnalyzeSubexpressions(const parser::Expr &);
+
   semantics::SemanticsContext &context_;
   FoldingContext &foldingContext_{context_.foldingContext()};
   std::map<parser::CharBlock, int> impliedDos_; // values are INTEGER kinds
@@ -391,6 +393,7 @@ class ExpressionAnalyzer {
   bool inDataStmtObject_{false};
   bool inDataStmtConstant_{false};
   bool inStmtFunctionDefinition_{false};
+  bool iterativelyAnalyzingSubexpressions_{false};
   friend class ArgumentAnalyzer;
 };
 

diff  --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index af6cce3d5b2c3..b61b97a80554b 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -29,6 +29,7 @@
 #include <functional>
 #include <optional>
 #include <set>
+#include <vector>
 
 // Typedef for optional generic expressions (ubiquitous in this file)
 using MaybeExpr =
@@ -3326,6 +3327,12 @@ MaybeExpr ExpressionAnalyzer::ExprOrVariable(
     result = Analyze(x.u);
   }
   if (result) {
+    if constexpr (std::is_same_v<PARSED, parser::Expr>) {
+      if (!isNullPointerOk_ && IsNullPointer(*result)) {
+        Say(source,
+            "NULL() may not be used as an expression in this context"_err_en_US);
+      }
+    }
     SetExpr(x, Fold(std::move(*result)));
     return x.typedExpr->v;
   } else {
@@ -3341,15 +3348,76 @@ MaybeExpr ExpressionAnalyzer::ExprOrVariable(
   }
 }
 
+// This is an optional preliminary pass over parser::Expr subtrees.
+// Given an expression tree, iteratively traverse it in a bottom-up order
+// to analyze all of its subexpressions.  A later normal top-down analysis
+// will then be able to use the results that will have been saved in the
+// parse tree without having to recurse deeply.  This technique keeps
+// absurdly deep expression parse trees from causing the analyzer to overflow
+// its stack.
+MaybeExpr ExpressionAnalyzer::IterativelyAnalyzeSubexpressions(
+    const parser::Expr &top) {
+  std::vector<const parser::Expr *> queue, finish;
+  queue.push_back(&top);
+  do {
+    const parser::Expr &expr{*queue.back()};
+    queue.pop_back();
+    if (!expr.typedExpr) {
+      const parser::Expr::IntrinsicUnary *unary{nullptr};
+      const parser::Expr::IntrinsicBinary *binary{nullptr};
+      common::visit(
+          [&unary, &binary](auto &y) {
+            if constexpr (std::is_convertible_v<decltype(&y),
+                              decltype(unary)>) {
+              // Don't evaluate a constant operand to Negate
+              if (!std::holds_alternative<parser::LiteralConstant>(
+                      y.v.value().u)) {
+                unary = &y;
+              }
+            } else if constexpr (std::is_convertible_v<decltype(&y),
+                                     decltype(binary)>) {
+              binary = &y;
+            }
+          },
+          expr.u);
+      if (unary) {
+        queue.push_back(&unary->v.value());
+      } else if (binary) {
+        queue.push_back(&std::get<0>(binary->t).value());
+        queue.push_back(&std::get<1>(binary->t).value());
+      }
+      finish.push_back(&expr);
+    }
+  } while (!queue.empty());
+  // Analyze the collected subexpressions in bottom-up order.
+  // On an error, bail out and leave partial results in place.
+  MaybeExpr result;
+  for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) {
+    const parser::Expr &expr{**riter};
+    result = ExprOrVariable(expr, expr.source);
+    if (!result) {
+      return result;
+    }
+  }
+  return result; // last value was from analysis of "top"
+}
+
 MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr &expr) {
-  if (useSavedTypedExprs_ && expr.typedExpr) {
-    return expr.typedExpr->v;
+  bool wasIterativelyAnalyzing{iterativelyAnalyzingSubexpressions_};
+  MaybeExpr result;
+  if (useSavedTypedExprs_) {
+    if (expr.typedExpr) {
+      return expr.typedExpr->v;
+    }
+    if (!wasIterativelyAnalyzing) {
+      iterativelyAnalyzingSubexpressions_ = true;
+      result = IterativelyAnalyzeSubexpressions(expr);
+    }
   }
-  MaybeExpr result{ExprOrVariable(expr, expr.source)};
-  if (!isNullPointerOk_ && result && IsNullPointer(*result)) {
-    Say(expr.source,
-        "NULL() may not be used as an expression in this context"_err_en_US);
+  if (!result) {
+    result = ExprOrVariable(expr, expr.source);
   }
+  iterativelyAnalyzingSubexpressions_ = wasIterativelyAnalyzing;
   return result;
 }
 
@@ -4017,7 +4085,7 @@ std::optional<ActualArgument> ArgumentAnalyzer::AnalyzeExpr(
     const parser::Expr &expr) {
   source_.ExtendToCover(expr.source);
   if (const Symbol *assumedTypeDummy{AssumedTypeDummy(expr)}) {
-    expr.typedExpr.Reset(new GenericExprWrapper{}, GenericExprWrapper::Deleter);
+    ResetExpr(expr);
     if (isProcedureCall_) {
       ActualArgument arg{ActualArgument::AssumedType{*assumedTypeDummy}};
       SetArgSourceLocation(arg, expr.source);

diff  --git a/flang/test/Evaluate/big-expr-tree.F90 b/flang/test/Evaluate/big-expr-tree.F90
new file mode 100644
index 0000000000000..feaa298723948
--- /dev/null
+++ b/flang/test/Evaluate/big-expr-tree.F90
@@ -0,0 +1,8 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Exercise parsing, expression analysis, and folding on a very tall expression tree
+! 32*32 = 1024 repetitions
+#define M0(x) x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x
+#define M1(x) x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x
+module m
+  logical, parameter :: test_1 = 32**2 .EQ. M1(M0(1))
+end module


        


More information about the flang-commits mailing list