[flang] [llvm] [flang][OpenMP] Overhaul implementation of ATOMIC construct (PR #137852)

Krzysztof Parzyszek via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 11:50:23 PDT 2025


================
@@ -2656,527 +2665,1857 @@ void OmpStructureChecker::Leave(const parser::OmpEndBlockDirective &x) {
   }
 }
 
-inline void OmpStructureChecker::ErrIfAllocatableVariable(
-    const parser::Variable &var) {
-  // Err out if the given symbol has
-  // ALLOCATABLE attribute
-  if (const auto *e{GetExpr(context_, var)})
-    for (const Symbol &symbol : evaluate::CollectSymbols(*e))
-      if (IsAllocatable(symbol)) {
-        const auto &designator =
-            std::get<common::Indirection<parser::Designator>>(var.u);
-        const auto *dataRef =
-            std::get_if<parser::DataRef>(&designator.value().u);
-        const parser::Name *name =
-            dataRef ? std::get_if<parser::Name>(&dataRef->u) : nullptr;
-        if (name)
-          context_.Say(name->source,
-              "%s must not have ALLOCATABLE "
-              "attribute"_err_en_US,
-              name->ToString());
+/// parser::Block is a list of executable constructs, parser::BlockConstruct
+/// is Fortran's BLOCK/ENDBLOCK construct.
+/// Strip the outermost BlockConstructs, return the reference to the Block
+/// in the executable part of the innermost of the stripped constructs.
+/// Specifically, if the given `block` has a single entry (it's a list), and
+/// the entry is a BlockConstruct, get the Block contained within. Repeat
+/// this step as many times as possible.
+/// If nothing can be stripped, `block` itself is returned; otherwise the
+/// result is a Block nested inside `block`.
+static const parser::Block &GetInnermostExecPart(const parser::Block &block) {
+  const parser::Block *iter{&block};
+  // Each iteration descends into one BLOCK construct. A list with zero or
+  // several entries cannot be stripped any further.
+  while (iter->size() == 1) {
+    const parser::ExecutionPartConstruct &ep{iter->front()};
+    if (auto *exec{std::get_if<parser::ExecutableConstruct>(&ep.u)}) {
+      using BlockConstruct = common::Indirection<parser::BlockConstruct>;
+      if (auto *bc{std::get_if<BlockConstruct>(&exec->u)}) {
+        iter = &std::get<parser::Block>(bc->value().t);
+        continue;
       }
+    }
+    // The lone entry is not a BLOCK construct: `iter` is the innermost part.
+    break;
+  }
+  return *iter;
 }
 
-inline void OmpStructureChecker::ErrIfLHSAndRHSSymbolsMatch(
-    const parser::Variable &var, const parser::Expr &expr) {
-  // Err out if the symbol on the LHS is also used on the RHS of the assignment
-  // statement
-  const auto *e{GetExpr(context_, expr)};
-  const auto *v{GetExpr(context_, var)};
-  if (e && v) {
-    auto vSyms{evaluate::GetSymbolVector(*v)};
-    const Symbol &varSymbol = vSyms.front();
-    for (const Symbol &symbol : evaluate::GetSymbolVector(*e)) {
-      if (varSymbol == symbol) {
-        const common::Indirection<parser::Designator> *designator =
-            std::get_if<common::Indirection<parser::Designator>>(&expr.u);
-        if (designator) {
-          auto *z{var.typedExpr.get()};
-          auto *c{expr.typedExpr.get()};
-          if (z->v == c->v) {
-            context_.Say(expr.source,
-                "RHS expression on atomic assignment statement cannot access '%s'"_err_en_US,
-                var.GetSource());
-          }
-        } else {
-          context_.Say(expr.source,
-              "RHS expression on atomic assignment statement cannot access '%s'"_err_en_US,
-              var.GetSource());
-        }
-      }
+// There is no consistent way to get the source of a given ActionStmt, so
+// extract the source information from Statement<ActionStmt> when we can,
+// and keep it around for error reporting in further analyses.
+struct SourcedActionStmt {
+  // Null when no action statement was found.
+  const parser::ActionStmt *stmt{nullptr};
+  // Source of the statement that `stmt` was taken from.
+  parser::CharBlock source;
+
+  // True iff an action statement is present.
+  operator bool() const { return stmt != nullptr; }
+};
+
+// Result of recognizing a conditional: an IF statement or an IF construct.
+struct AnalyzedCondStmt {
+  // The IF condition.
+  SomeExpr cond{evaluate::NullPointer{}}; // Default ctor is deleted
+  // Source of the IF statement, or of the IF-THEN statement of a construct.
+  parser::CharBlock source;
+  // Action statements for the true and false outcomes of `cond`; `iff` is
+  // empty when there is no ELSE branch.
+  SourcedActionStmt ift, iff;
+};
+
+// Return the action statement wrapped by `x` together with its source, if `x`
+// is an ExecutableConstruct holding a Statement<ActionStmt>. For a null `x`
+// or any other construct, return an empty SourcedActionStmt.
+static SourcedActionStmt GetActionStmt(
+    const parser::ExecutionPartConstruct *x) {
+  if (x == nullptr) {
+    return SourcedActionStmt{};
+  }
+  if (auto *exec{std::get_if<parser::ExecutableConstruct>(&x->u)}) {
+    using ActionStmt = parser::Statement<parser::ActionStmt>;
+    if (auto *stmt{std::get_if<ActionStmt>(&exec->u)}) {
+      return SourcedActionStmt{&stmt->statement, stmt->source};
     }
   }
+  return SourcedActionStmt{};
 }
 
-inline void OmpStructureChecker::ErrIfNonScalarAssignmentStmt(
-    const parser::Variable &var, const parser::Expr &expr) {
-  // Err out if either the variable on the LHS or the expression on the RHS of
-  // the assignment statement are non-scalar (i.e. have rank > 0 or is of
-  // CHARACTER type)
-  const auto *e{GetExpr(context_, expr)};
-  const auto *v{GetExpr(context_, var)};
-  if (e && v) {
-    if (e->Rank() != 0 ||
-        (e->GetType().has_value() &&
-            e->GetType().value().category() == common::TypeCategory::Character))
-      context_.Say(expr.source,
-          "Expected scalar expression "
-          "on the RHS of atomic assignment "
-          "statement"_err_en_US);
-    if (v->Rank() != 0 ||
-        (v->GetType().has_value() &&
-            v->GetType()->category() == common::TypeCategory::Character))
-      context_.Say(var.GetSource(),
-          "Expected scalar variable "
-          "on the LHS of atomic assignment "
-          "statement"_err_en_US);
-  }
-}
-
-template <typename T, typename D>
-bool OmpStructureChecker::IsOperatorValid(const T &node, const D &variable) {
-  using AllowedBinaryOperators =
-      std::variant<parser::Expr::Add, parser::Expr::Multiply,
-          parser::Expr::Subtract, parser::Expr::Divide, parser::Expr::AND,
-          parser::Expr::OR, parser::Expr::EQV, parser::Expr::NEQV>;
-  using BinaryOperators = std::variant<parser::Expr::Add,
-      parser::Expr::Multiply, parser::Expr::Subtract, parser::Expr::Divide,
-      parser::Expr::AND, parser::Expr::OR, parser::Expr::EQV,
-      parser::Expr::NEQV, parser::Expr::Power, parser::Expr::Concat,
-      parser::Expr::LT, parser::Expr::LE, parser::Expr::EQ, parser::Expr::NE,
-      parser::Expr::GE, parser::Expr::GT>;
-
-  if constexpr (common::HasMember<T, BinaryOperators>) {
-    const auto &variableName{variable.GetSource().ToString()};
-    const auto &exprLeft{std::get<0>(node.t)};
-    const auto &exprRight{std::get<1>(node.t)};
-    if ((exprLeft.value().source.ToString() != variableName) &&
-        (exprRight.value().source.ToString() != variableName)) {
-      context_.Say(variable.GetSource(),
-          "Atomic update statement should be of form "
-          "`%s = %s operator expr` OR `%s = expr operator %s`"_err_en_US,
-          variableName, variableName, variableName, variableName);
-    }
-    return common::HasMember<T, AllowedBinaryOperators>;
+// Block overload: succeeds only when `block` has exactly one entry and that
+// entry is an action statement; otherwise returns an empty SourcedActionStmt.
+static SourcedActionStmt GetActionStmt(const parser::Block &block) {
+  if (block.size() == 1) {
+    return GetActionStmt(&block.front());
   }
-  return false;
+  return SourcedActionStmt{};
 }
 
-void OmpStructureChecker::CheckAtomicCaptureStmt(
-    const parser::AssignmentStmt &assignmentStmt) {
-  const auto &var{std::get<parser::Variable>(assignmentStmt.t)};
-  const auto &expr{std::get<parser::Expr>(assignmentStmt.t)};
-  common::visit(
-      common::visitors{
-          [&](const common::Indirection<parser::Designator> &designator) {
-            const auto *dataRef =
-                std::get_if<parser::DataRef>(&designator.value().u);
-            const auto *name =
-                dataRef ? std::get_if<parser::Name>(&dataRef->u) : nullptr;
-            if (name && IsAllocatable(*name->symbol))
-              context_.Say(name->source,
-                  "%s must not have ALLOCATABLE "
-                  "attribute"_err_en_US,
-                  name->ToString());
-          },
-          [&](const auto &) {
-            // Anything other than a `parser::Designator` is not allowed
-            context_.Say(expr.source,
-                "Expected scalar variable "
-                "of intrinsic type on RHS of atomic "
-                "assignment statement"_err_en_US);
-          }},
-      expr.u);
-  ErrIfLHSAndRHSSymbolsMatch(var, expr);
-  ErrIfNonScalarAssignmentStmt(var, expr);
-}
-
-void OmpStructureChecker::CheckAtomicWriteStmt(
-    const parser::AssignmentStmt &assignmentStmt) {
-  const auto &var{std::get<parser::Variable>(assignmentStmt.t)};
-  const auto &expr{std::get<parser::Expr>(assignmentStmt.t)};
-  ErrIfAllocatableVariable(var);
-  ErrIfLHSAndRHSSymbolsMatch(var, expr);
-  ErrIfNonScalarAssignmentStmt(var, expr);
-}
-
-void OmpStructureChecker::CheckAtomicUpdateStmt(
-    const parser::AssignmentStmt &assignment) {
-  const auto &expr{std::get<parser::Expr>(assignment.t)};
-  const auto &var{std::get<parser::Variable>(assignment.t)};
-  bool isIntrinsicProcedure{false};
-  bool isValidOperator{false};
-  common::visit(
-      common::visitors{
-          [&](const common::Indirection<parser::FunctionReference> &x) {
-            isIntrinsicProcedure = true;
-            const auto &procedureDesignator{
-                std::get<parser::ProcedureDesignator>(x.value().v.t)};
-            const parser::Name *name{
-                std::get_if<parser::Name>(&procedureDesignator.u)};
-            if (name &&
-                !(name->source == "max" || name->source == "min" ||
-                    name->source == "iand" || name->source == "ior" ||
-                    name->source == "ieor")) {
-              context_.Say(expr.source,
-                  "Invalid intrinsic procedure name in "
-                  "OpenMP ATOMIC (UPDATE) statement"_err_en_US);
-            }
-          },
-          [&](const auto &x) {
-            if (!IsOperatorValid(x, var)) {
-              context_.Say(expr.source,
-                  "Invalid or missing operator in atomic update "
-                  "statement"_err_en_US);
-            } else
-              isValidOperator = true;
-          },
-      },
-      expr.u);
-  if (const auto *e{GetExpr(context_, expr)}) {
-    const auto *v{GetExpr(context_, var)};
-    if (e->Rank() != 0 ||
-        (e->GetType().has_value() &&
-            e->GetType().value().category() == common::TypeCategory::Character))
-      context_.Say(expr.source,
-          "Expected scalar expression "
-          "on the RHS of atomic update assignment "
-          "statement"_err_en_US);
-    if (v->Rank() != 0 ||
-        (v->GetType().has_value() &&
-            v->GetType()->category() == common::TypeCategory::Character))
-      context_.Say(var.GetSource(),
-          "Expected scalar variable "
-          "on the LHS of atomic update assignment "
-          "statement"_err_en_US);
-    auto vSyms{evaluate::GetSymbolVector(*v)};
-    const Symbol &varSymbol = vSyms.front();
-    int numOfSymbolMatches{0};
-    SymbolVector exprSymbols{evaluate::GetSymbolVector(*e)};
-    for (const Symbol &symbol : exprSymbols) {
-      if (varSymbol == symbol) {
-        numOfSymbolMatches++;
-      }
-    }
-    if (isIntrinsicProcedure) {
-      std::string varName = var.GetSource().ToString();
-      if (numOfSymbolMatches != 1)
-        context_.Say(expr.source,
-            "Intrinsic procedure"
-            " arguments in atomic update statement"
-            " must have exactly one occurence of '%s'"_err_en_US,
-            varName);
-      else if (varSymbol != exprSymbols.front() &&
-          varSymbol != exprSymbols.back())
-        context_.Say(expr.source,
-            "Atomic update statement "
-            "should be of the form `%s = intrinsic_procedure(%s, expr_list)` "
-            "OR `%s = intrinsic_procedure(expr_list, %s)`"_err_en_US,
-            varName, varName, varName, varName);
-    } else if (isValidOperator) {
-      if (numOfSymbolMatches != 1)
-        context_.Say(expr.source,
-            "Exactly one occurence of '%s' "
-            "expected on the RHS of atomic update assignment statement"_err_en_US,
-            var.GetSource().ToString());
-    }
+// Compute the `evaluate::Assignment` from parser::ActionStmt. The assumption
+// is that the ActionStmt will be either an assignment or a pointer-assignment,
+// otherwise return std::nullopt. std::nullopt is also returned for a null `x`
+// and when the statement carries no typed assignment.
+static std::optional<evaluate::Assignment> GetEvaluateAssignment(
+    const parser::ActionStmt *x) {
+  if (x == nullptr) {
+    return std::nullopt;
   }
 
-  ErrIfAllocatableVariable(var);
+  using AssignmentStmt = common::Indirection<parser::AssignmentStmt>;
+  using PointerAssignmentStmt =
+      common::Indirection<parser::PointerAssignmentStmt>;
+  using TypedAssignment = parser::AssignmentStmt::TypedAssignment;
+
+  return common::visit(
+      [](auto &&s) -> std::optional<evaluate::Assignment> {
+        using BareS = llvm::remove_cvref_t<decltype(s)>;
+        if constexpr (std::is_same_v<BareS, AssignmentStmt> ||
+            std::is_same_v<BareS, PointerAssignmentStmt>) {
+          const TypedAssignment &typed{s.value().typedAssignment};
+          // ForwardOwningPointer                 typedAssignment
+          // `- GenericAssignmentWrapper          ^.get()
+          //    `- std::optional<Assignment>      ^->v
+          // The owning pointer may be unset (nothing stored a wrapper in it);
+          // report "no assignment" instead of dereferencing null.
+          if (const auto *wrapper{typed.get()}) {
+            return wrapper->v;
+          }
+          return std::nullopt;
+        } else {
+          return std::nullopt;
+        }
+      },
+      x->u);
 }
 
-void OmpStructureChecker::CheckAtomicCompareConstruct(
-    const parser::OmpAtomicCompare &atomicCompareConstruct) {
+// Recognize a conditional statement: either an IF statement
+// `IF (cond) action-stmt`, or an IF construct without ELSE IF blocks whose
+// THEN block (and ELSE block, if present) consists of a single action
+// statement. Return std::nullopt for a null `x`, for anything else, or when
+// the condition could not be converted to an evaluate::Expr.
+static std::optional<AnalyzedCondStmt> AnalyzeConditionalStmt(
+    const parser::ExecutionPartConstruct *x) {
+  if (x == nullptr) {
+    return std::nullopt;
+  }
 
-  // TODO: Check that the if-stmt is `if (var == expr) var = new`
-  //       [with or without then/end-do]
+  // Extract the evaluate::Expr from ScalarLogicalExpr.
+  auto getFromLogical{[](const parser::ScalarLogicalExpr &logical) {
+    // ScalarLogicalExpr is Scalar<Logical<common::Indirection<Expr>>>
+    const parser::Expr &expr{logical.thing.thing.value()};
+    return GetEvaluateExpr(expr);
+  }};
 
-  unsigned version{context_.langOptions().OpenMPVersion};
-  if (version < 51) {
-    context_.Say(atomicCompareConstruct.source,
-        "%s construct not allowed in %s, %s"_err_en_US,
-        atomicCompareConstruct.source, ThisVersion(version), TryVersion(51));
-  }
-
-  // TODO: More work needed here. Some of the Update restrictions need to
-  // be added, but Update isn't the same either.
-}
-
-// TODO: Allow cond-update-stmt once compare clause is supported.
-void OmpStructureChecker::CheckAtomicCaptureConstruct(
-    const parser::OmpAtomicCapture &atomicCaptureConstruct) {
-  const parser::AssignmentStmt &stmt1 =
-      std::get<parser::OmpAtomicCapture::Stmt1>(atomicCaptureConstruct.t)
-          .v.statement;
-  const auto &stmt1Var{std::get<parser::Variable>(stmt1.t)};
-  const auto &stmt1Expr{std::get<parser::Expr>(stmt1.t)};
-  const auto *v1 = GetExpr(context_, stmt1Var);
-  const auto *e1 = GetExpr(context_, stmt1Expr);
-
-  const parser::AssignmentStmt &stmt2 =
-      std::get<parser::OmpAtomicCapture::Stmt2>(atomicCaptureConstruct.t)
-          .v.statement;
-  const auto &stmt2Var{std::get<parser::Variable>(stmt2.t)};
-  const auto &stmt2Expr{std::get<parser::Expr>(stmt2.t)};
-  const auto *v2 = GetExpr(context_, stmt2Var);
-  const auto *e2 = GetExpr(context_, stmt2Expr);
-
-  if (e1 && v1 && e2 && v2) {
-    if (semantics::checkForSingleVariableOnRHS(stmt1)) {
-      CheckAtomicCaptureStmt(stmt1);
-      if (semantics::checkForSymbolMatch(v2, e2)) {
-        // ATOMIC CAPTURE construct is of the form [capture-stmt, update-stmt]
-        CheckAtomicUpdateStmt(stmt2);
-      } else {
-        // ATOMIC CAPTURE construct is of the form [capture-stmt, write-stmt]
-        CheckAtomicWriteStmt(stmt2);
+  // Recognize either
+  // ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> IfStmt, or
+  // ExecutionPartConstruct -> ExecutableConstruct -> IfConstruct.
+
+  if (auto &&action{GetActionStmt(x)}) {
+    if (auto *ifs{std::get_if<common::Indirection<parser::IfStmt>>(
+            &action.stmt->u)}) {
+      const parser::IfStmt &s{ifs->value()};
+      auto &&maybeCond{
+          getFromLogical(std::get<parser::ScalarLogicalExpr>(s.t))};
+      auto &thenStmt{
+          std::get<parser::UnlabeledStatement<parser::ActionStmt>>(s.t)};
+      if (maybeCond) {
+        return AnalyzedCondStmt{std::move(*maybeCond), action.source,
+            SourcedActionStmt{&thenStmt.statement, thenStmt.source},
+            SourcedActionStmt{}};
       }
-      if (!(*e1 == *v2)) {
-        context_.Say(stmt1Expr.source,
-            "Captured variable/array element/derived-type component %s expected to be assigned in the second statement of ATOMIC CAPTURE construct"_err_en_US,
-            stmt1Expr.source);
-      }
-    } else if (semantics::checkForSymbolMatch(v1, e1) &&
-        semantics::checkForSingleVariableOnRHS(stmt2)) {
-      // ATOMIC CAPTURE construct is of the form [update-stmt, capture-stmt]
-      CheckAtomicUpdateStmt(stmt1);
-      CheckAtomicCaptureStmt(stmt2);
-      // Variable updated in stmt1 should be captured in stmt2
-      if (!(*v1 == *e2)) {
-        context_.Say(stmt1Var.GetSource(),
-            "Updated variable/array element/derived-type component %s expected to be captured in the second statement of ATOMIC CAPTURE construct"_err_en_US,
-            stmt1Var.GetSource());
-      }
-    } else {
-      context_.Say(stmt1Expr.source,
-          "Invalid ATOMIC CAPTURE construct statements. Expected one of [update-stmt, capture-stmt], [capture-stmt, update-stmt], or [capture-stmt, write-stmt]"_err_en_US);
     }
+    // Either not an IF statement, or its condition could not be analyzed.
+    return std::nullopt;
   }
-}
 
-void OmpStructureChecker::CheckAtomicMemoryOrderClause(
-    const parser::OmpAtomicClauseList *leftHandClauseList,
-    const parser::OmpAtomicClauseList *rightHandClauseList) {
-  int numMemoryOrderClause{0};
-  int numFailClause{0};
-  auto checkForValidMemoryOrderClause = [&](const parser::OmpAtomicClauseList
-                                                *clauseList) {
-    for (const auto &clause : clauseList->v) {
-      if (std::get_if<parser::OmpFailClause>(&clause.u)) {
-        numFailClause++;
-        if (numFailClause > 1) {
-          context_.Say(clause.source,
-              "More than one FAIL clause not allowed on OpenMP ATOMIC construct"_err_en_US);
-          return;
+  // Otherwise, look for an IF construct.
+  if (auto *exec{std::get_if<parser::ExecutableConstruct>(&x->u)}) {
+    if (auto *ifc{
+            std::get_if<common::Indirection<parser::IfConstruct>>(&exec->u)}) {
+      using ElseBlock = parser::IfConstruct::ElseBlock;
+      using ElseIfBlock = parser::IfConstruct::ElseIfBlock;
+      const parser::IfConstruct &s{ifc->value()};
+
+      if (!std::get<std::list<ElseIfBlock>>(s.t).empty()) {
+        // Not expecting any else-if statements.
+        return std::nullopt;
+      }
+      auto &stmt{std::get<parser::Statement<parser::IfThenStmt>>(s.t)};
+      auto &&maybeCond{getFromLogical(
+          std::get<parser::ScalarLogicalExpr>(stmt.statement.t))};
+      if (!maybeCond) {
+        return std::nullopt;
+      }
+
+      if (auto &maybeElse{std::get<std::optional<ElseBlock>>(s.t)}) {
+        AnalyzedCondStmt result{std::move(*maybeCond), stmt.source,
+            GetActionStmt(std::get<parser::Block>(s.t)),
+            GetActionStmt(std::get<parser::Block>(maybeElse->t))};
+        // Both the THEN and the ELSE block must be a single action statement.
+        if (result.ift.stmt && result.iff.stmt) {
+          return result;
         }
       } else {
-        if (std::get_if<parser::OmpMemoryOrderClause>(&clause.u)) {
-          numMemoryOrderClause++;
-          if (numMemoryOrderClause > 1) {
-            context_.Say(clause.source,
-                "More than one memory order clause not allowed on OpenMP ATOMIC construct"_err_en_US);
-            return;
-          }
+        AnalyzedCondStmt result{std::move(*maybeCond), stmt.source,
+            GetActionStmt(std::get<parser::Block>(s.t))};
+        // Without an ELSE block only the THEN block has to qualify.
+        if (result.ift.stmt) {
+          return result;
         }
       }
     }
-  };
-  if (leftHandClauseList) {
-    checkForValidMemoryOrderClause(leftHandClauseList);
-  }
-  if (rightHandClauseList) {
-    checkForValidMemoryOrderClause(rightHandClauseList);
+    // Not an IF construct, or an IF construct whose branches did not qualify.
+    return std::nullopt;
   }
-}
 
-void OmpStructureChecker::Enter(const parser::OpenMPAtomicConstruct &x) {
-  common::visit(
-      common::visitors{
-          [&](const parser::OmpAtomic &atomicConstruct) {
-            const auto &dir{std::get<parser::Verbatim>(atomicConstruct.t)};
-            PushContextAndClauseSets(
-                dir.source, llvm::omp::Directive::OMPD_atomic);
-            CheckAtomicUpdateStmt(
-                std::get<parser::Statement<parser::AssignmentStmt>>(
-                    atomicConstruct.t)
-                    .statement);
-            CheckAtomicMemoryOrderClause(
-                &std::get<parser::OmpAtomicClauseList>(atomicConstruct.t),
-                nullptr);
-            CheckHintClause<const parser::OmpAtomicClauseList>(
-                &std::get<parser::OmpAtomicClauseList>(atomicConstruct.t),
-                nullptr, "ATOMIC");
-          },
-          [&](const parser::OmpAtomicUpdate &atomicUpdate) {
-            const auto &dir{std::get<parser::Verbatim>(atomicUpdate.t)};
-            PushContextAndClauseSets(
-                dir.source, llvm::omp::Directive::OMPD_atomic);
-            CheckAtomicUpdateStmt(
-                std::get<parser::Statement<parser::AssignmentStmt>>(
-                    atomicUpdate.t)
-                    .statement);
-            CheckAtomicMemoryOrderClause(
-                &std::get<0>(atomicUpdate.t), &std::get<2>(atomicUpdate.t));
-            CheckHintClause<const parser::OmpAtomicClauseList>(
-                &std::get<0>(atomicUpdate.t), &std::get<2>(atomicUpdate.t),
-                "UPDATE");
-          },
-          [&](const parser::OmpAtomicRead &atomicRead) {
-            const auto &dir{std::get<parser::Verbatim>(atomicRead.t)};
-            PushContextAndClauseSets(
-                dir.source, llvm::omp::Directive::OMPD_atomic);
-            CheckAtomicMemoryOrderClause(
-                &std::get<0>(atomicRead.t), &std::get<2>(atomicRead.t));
-            CheckHintClause<const parser::OmpAtomicClauseList>(
-                &std::get<0>(atomicRead.t), &std::get<2>(atomicRead.t), "READ");
-            CheckAtomicCaptureStmt(
-                std::get<parser::Statement<parser::AssignmentStmt>>(
-                    atomicRead.t)
-                    .statement);
-          },
-          [&](const parser::OmpAtomicWrite &atomicWrite) {
-            const auto &dir{std::get<parser::Verbatim>(atomicWrite.t)};
-            PushContextAndClauseSets(
-                dir.source, llvm::omp::Directive::OMPD_atomic);
-            CheckAtomicMemoryOrderClause(
-                &std::get<0>(atomicWrite.t), &std::get<2>(atomicWrite.t));
-            CheckHintClause<const parser::OmpAtomicClauseList>(
-                &std::get<0>(atomicWrite.t), &std::get<2>(atomicWrite.t),
-                "WRITE");
-            CheckAtomicWriteStmt(
-                std::get<parser::Statement<parser::AssignmentStmt>>(
-                    atomicWrite.t)
-                    .statement);
-          },
-          [&](const parser::OmpAtomicCapture &atomicCapture) {
-            const auto &dir{std::get<parser::Verbatim>(atomicCapture.t)};
-            PushContextAndClauseSets(
-                dir.source, llvm::omp::Directive::OMPD_atomic);
-            CheckAtomicMemoryOrderClause(
-                &std::get<0>(atomicCapture.t), &std::get<2>(atomicCapture.t));
-            CheckHintClause<const parser::OmpAtomicClauseList>(
-                &std::get<0>(atomicCapture.t), &std::get<2>(atomicCapture.t),
-                "CAPTURE");
-            CheckAtomicCaptureConstruct(atomicCapture);
-          },
-          [&](const parser::OmpAtomicCompare &atomicCompare) {
-            const auto &dir{std::get<parser::Verbatim>(atomicCompare.t)};
-            PushContextAndClauseSets(
-                dir.source, llvm::omp::Directive::OMPD_atomic);
-            CheckAtomicMemoryOrderClause(
-                &std::get<0>(atomicCompare.t), &std::get<2>(atomicCompare.t));
-            CheckHintClause<const parser::OmpAtomicClauseList>(
-                &std::get<0>(atomicCompare.t), &std::get<2>(atomicCompare.t),
-                "CAPTURE");
-            CheckAtomicCompareConstruct(atomicCompare);
-          },
-      },
-      x.u);
+  // `x` is not an executable construct.
+  return std::nullopt;
 }
 
-void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) {
-  dirContext_.pop_back();
+/// Split the source of an assignment (`lhs = rhs`) or pointer assignment
+/// (`lhs => rhs`) statement into the trimmed sources of its two sides.
+///
+/// The operator is the first '=' (or "=>") that is outside of any
+/// parentheses, brackets, or character literals, and that is not part of
+/// one of the relational operators <=, >=, ==, or /=. The scan runs forward
+/// from the left-hand side (a variable) because the right-hand side may
+/// contain "==" or keyword arguments such as "kind=8", which a backward scan
+/// would mistake for the assignment operator.
+static std::pair<parser::CharBlock, parser::CharBlock> SplitAssignmentSource(
+    parser::CharBlock source) {
+  // Strip leading and trailing blanks. The bounds are checked before each
+  // dereference, so an all-blank view never reads past its end.
+  auto trim{[](std::string_view v) {
+    const char *begin{v.data()};
+    const char *end{begin + v.size()};
+    while (begin != end && *begin == ' ') {
+      ++begin;
+    }
+    while (begin != end && end[-1] == ' ') {
+      --end;
+    }
+    assert(begin != end && "Source should not be empty");
+    return parser::CharBlock(begin, end - begin);
+  }};
+
+  std::string_view sv(source.begin(), source.size());
+  int depth{0}; // Nesting level of () and [].
+  char quote{'\0'}; // Delimiter of the character literal being skipped.
+  for (size_t i{0}, size{sv.size()}; i != size; ++i) {
+    char c{sv[i]};
+    if (quote != '\0') {
+      // A doubled delimiter ('it''s') closes and immediately reopens the
+      // literal, which this toggling handles correctly.
+      if (c == quote) {
+        quote = '\0';
+      }
+    } else if (c == '\'' || c == '"') {
+      quote = c;
+    } else if (c == '(' || c == '[') {
+      ++depth;
+    } else if (c == ')' || c == ']') {
+      --depth;
+    } else if (c == '=' && depth == 0) {
+      char prev{i > 0 ? sv[i - 1] : ' '};
+      char next{i + 1 < size ? sv[i + 1] : ' '};
+      if (prev == '<' || prev == '>' || prev == '=' || prev == '/' ||
+          next == '=') {
+        continue; // Part of <=, >=, ==, or /=.
+      }
+      // "=>" for pointer assignment, "=" otherwise.
+      size_t opLength{next == '>' ? size_t{2} : size_t{1}};
+      std::string_view lhs(sv.data(), i);
+      std::string_view rhs(sv.data() + i + opLength, size - i - opLength);
+      return std::make_pair(trim(lhs), trim(rhs));
+    }
+  }
+  llvm_unreachable("Could not find assignment operator");
+}
+
+namespace atomic {
+
+/// Append every element of `other` to the end of `accum`, preserving order,
+/// by moving (not copying) each element out of `other`.
+template <typename V> static void MoveAppend(V &accum, V &&other) {
+  const auto stop{other.end()};
+  for (auto cursor{other.begin()}; cursor != stop; ++cursor) {
+    accum.push_back(std::move(*cursor));
+  }
+}
+
+/// Kind of the top-level operation of an expression, as classified by the
+/// argument extraction that follows (ArgumentExtractor / OperationCode).
+enum class Operator {
+  // Unrecognized operation; ToString prints it as "??".
+  Unk,
+  // Operators that are officially allowed in the update operation
+  Add,
+  And,
+  Associated,
+  Div,
+  Eq,
+  Eqv,
+  Ge, // extension
+  Gt,
+  Identity, // extension: x = x is allowed (*), but we should never print
+            // "identity" as the name of the operator
+  Le, // extension
+  Lt,
+  Max,
+  Min,
+  Mul,
+  Ne, // extension
+  Neqv,
+  Or,
+  Sub,
+  // Operators that we recognize for technical reasons
+  True,
+  False,
+  Not,
+  Convert,
+  Resize,
+  Intrinsic,
+  Call,
+  Pow,
+
+  // (*): "x = x + 0" is a valid update statement, but it will be folded
+  //      to "x = x" by the time we look at it. Since the source statements
+  //      "x = x" and "x = x + 0" will end up looking the same, accept the
+  //      former as an extension.
+};
+
+/// Return the printable name of `op`.
+///
+/// Every enumerator is listed explicitly and there is no `default` label, so
+/// -Wswitch reports any enumerator added to `Operator` without a name here.
+std::string ToString(Operator op) {
+  switch (op) {
+  case Operator::Unk:
+    return "??";
+  case Operator::Add:
+    return "+";
+  case Operator::And:
+    return "AND";
+  case Operator::Associated:
+    return "ASSOCIATED";
+  case Operator::Div:
+    return "/";
+  case Operator::Eq:
+    return "==";
+  case Operator::Eqv:
+    return "EQV";
+  case Operator::Ge:
+    return ">=";
+  case Operator::Gt:
+    return ">";
+  case Operator::Identity:
+    return "identity";
+  case Operator::Le:
+    return "<=";
+  case Operator::Lt:
+    return "<";
+  case Operator::Max:
+    return "MAX";
+  case Operator::Min:
+    return "MIN";
+  case Operator::Mul:
+    return "*";
+  case Operator::Neqv:
+    return "NEQV/EOR";
+  case Operator::Ne:
+    return "/=";
+  case Operator::Or:
+    return "OR";
+  case Operator::Sub:
+    return "-";
+  case Operator::True:
+    return ".TRUE.";
+  case Operator::False:
+    return ".FALSE.";
+  case Operator::Not:
+    return "NOT";
+  case Operator::Convert:
+    return "type-conversion";
+  case Operator::Resize:
+    return "resize";
+  case Operator::Intrinsic:
+    return "intrinsic";
+  case Operator::Call:
+    return "function-call";
+  case Operator::Pow:
+    return "**";
+  }
+  // Only reachable for a value that is not one of the enumerators.
+  return "??";
 }
 
-// Clauses
-// Mainly categorized as
-// 1. Checks on 'OmpClauseList' from 'parse-tree.h'.
-// 2. Checks on clauses which fall under 'struct OmpClause' from parse-tree.h.
-// 3. Checks on clauses which are not in 'struct OmpClause' from parse-tree.h.
+template <bool IgnoreResizingConverts> //
+struct ArgumentExtractor
+    : public evaluate::Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+          std::pair<Operator, std::vector<SomeExpr>>, false> {
+  using Arguments = std::vector<SomeExpr>;
+  using Result = std::pair<Operator, Arguments>;
+  using Base = evaluate::Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+      Result, false>;
+  static constexpr auto IgnoreResizes = IgnoreResizingConverts;
+  static constexpr auto Logical = common::TypeCategory::Logical;
+  ArgumentExtractor() : Base(*this) {}
 
-void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
-  // 2.7.1 Loop Construct Restriction
-  if (llvm::omp::allDoSet.test(GetContext().directive)) {
-    if (auto *clause{FindClause(llvm::omp::Clause::OMPC_schedule)}) {
-      // only one schedule clause is allowed
-      const auto &schedClause{std::get<parser::OmpClause::Schedule>(clause->u)};
-      auto &modifiers{OmpGetModifiers(schedClause.v)};
-      auto *ordering{
-          OmpGetUniqueModifier<parser::OmpOrderingModifier>(modifiers)};
-      if (ordering &&
-          ordering->v == parser::OmpOrderingModifier::Value::Nonmonotonic) {
-        if (FindClause(llvm::omp::Clause::OMPC_ordered)) {
-          context_.Say(clause->source,
-              "The NONMONOTONIC modifier cannot be specified "
-              "if an ORDERED clause is specified"_err_en_US);
-        }
+  Result Default() const { return {}; }
+
+  using Base::operator();
+
+  template <int Kind> //
+  Result operator()(
+      const evaluate::Constant<evaluate::Type<Logical, Kind>> &x) const {
+    if (const auto &val{x.GetScalarValue()}) {
+      return val->IsTrue() ? std::make_pair(Operator::True, Arguments{})
+                           : std::make_pair(Operator::False, Arguments{});
+    }
+    return Default();
+  }
+
+  template <typename R> //
+  Result operator()(const evaluate::FunctionRef<R> &x) const {
+    Result result{OperationCode(x.proc()), {}};
+    for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
+      if (auto *e{x.UnwrapArgExpr(i)}) {
+        result.second.push_back(*e);
       }
     }
+    return result;
+  }
 
-    if (auto *clause{FindClause(llvm::omp::Clause::OMPC_ordered)}) {
-      // only one ordered clause is allowed
-      const auto &orderedClause{
-          std::get<parser::OmpClause::Ordered>(clause->u)};
+  template <typename D, typename R, typename... Os>
+  Result operator()(const evaluate::Operation<D, R, Os...> &x) const {
+    if constexpr (std::is_same_v<D, evaluate::Parentheses<R>>) {
+      // Ignore top-level parentheses.
+      return (*this)(x.template operand<0>());
+    }
+    if constexpr (IgnoreResizes &&
+        std::is_same_v<D, evaluate::Convert<R, R::category>>) {
+      // Ignore conversions within the same category.
+      // Atomic operations on int(kind=1) may be implicitly widened
+      // to int(kind=4) for example.
+      return (*this)(x.template operand<0>());
+    } else {
+      return std::make_pair(
+          OperationCode(x), OperationArgs(x, std::index_sequence_for<Os...>{}));
+    }
+  }
 
-      if (orderedClause.v) {
-        CheckNotAllowedIfClause(
-            llvm::omp::Clause::OMPC_ordered, {llvm::omp::Clause::OMPC_linear});
+  template <typename T> //
+  Result operator()(const evaluate::Designator<T> &x) const {
+    evaluate::Designator<T> copy{x};
+    Result result{Operator::Identity, {AsGenericExpr(std::move(copy))}};
+    return result;
+  }
 
-        if (auto *clause2{FindClause(llvm::omp::Clause::OMPC_collapse)}) {
-          const auto &collapseClause{
-              std::get<parser::OmpClause::Collapse>(clause2->u)};
-          // ordered and collapse both have parameters
-          if (const auto orderedValue{GetIntValue(orderedClause.v)}) {
-            if (const auto collapseValue{GetIntValue(collapseClause.v)}) {
-              if (*orderedValue > 0 && *orderedValue < *collapseValue) {
-                context_.Say(clause->source,
-                    "The parameter of the ORDERED clause must be "
-                    "greater than or equal to "
-                    "the parameter of the COLLAPSE clause"_err_en_US);
-              }
-            }
-          }
-        }
+  /// Traversal hook that merges the results produced for sibling
+  /// subexpressions.
+  template <typename... Rs> //
+  Result Combine(Result &&result, Rs &&...results) const {
+    // There shouldn't be any combining needed, since we're stopping the
+    // traversal at the top-level operation, but implement one that picks
+    // the first non-empty result.
+    // A result counts as non-empty when its argument list (second) is
+    // non-empty; if no result qualifies, the last one is returned.
+    if constexpr (sizeof...(Rs) == 0) {
+      return std::move(result);
+    } else {
+      if (!result.second.empty()) {
+        return std::move(result);
+      } else {
+        return Combine(std::move(results)...);
       }
+    }
+  }
 
-      // TODO: ordered region binding check (requires nesting implementation)
+private:
+  /// Operator code for a logical operation; an operator not listed below
+  /// yields Unk.
+  template <typename... Ts, int Kind>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op)
+      const {
+    Operator code{Operator::Unk};
+    switch (op.derived().logicalOperator) {
+    case common::LogicalOperator::And:
+      code = Operator::And;
+      break;
+    case common::LogicalOperator::Or:
+      code = Operator::Or;
+      break;
+    case common::LogicalOperator::Eqv:
+      code = Operator::Eqv;
+      break;
+    case common::LogicalOperator::Neqv:
+      code = Operator::Neqv;
+      break;
+    case common::LogicalOperator::Not:
+      code = Operator::Not;
+      break;
+    }
+    return code;
+  }
+  /// Operator code for a relational operation; an operator not listed below
+  /// yields Unk.
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) const {
+    Operator code{Operator::Unk};
+    switch (op.derived().opr) {
+    case common::RelationalOperator::LT:
+      code = Operator::Lt;
+      break;
+    case common::RelationalOperator::LE:
+      code = Operator::Le;
+      break;
+    case common::RelationalOperator::EQ:
+      code = Operator::Eq;
+      break;
+    case common::RelationalOperator::NE:
+      code = Operator::Ne;
+      break;
+    case common::RelationalOperator::GE:
+      code = Operator::Ge;
+      break;
+    case common::RelationalOperator::GT:
+      code = Operator::Gt;
+      break;
+    }
+    return code;
+  }
+  // Arithmetic operations map one-to-one onto operator codes; the operand
+  // itself is not inspected, so the parameter is left unnamed.
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Add<T>, Ts...> &) const {
+    return Operator::Add;
+  }
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Subtract<T>, Ts...> &) const {
+    return Operator::Sub;
+  }
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Multiply<T>, Ts...> &) const {
+    return Operator::Mul;
+  }
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Divide<T>, Ts...> &) const {
+    return Operator::Div;
+  }
+  // Both real and integer exponentiation are reported as Pow.
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Power<T>, Ts...> &) const {
+    return Operator::Pow;
+  }
+  template <typename T, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &) const {
+    return Operator::Pow;
+  }
+  /// A conversion whose source category C equals the result category
+  /// (e.g. an integer kind change) is reported as Resize; a conversion
+  /// between categories is reported as Convert.
+  template <typename T, common::TypeCategory C, typename... Ts>
+  Operator OperationCode(
+      const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) const {
+    if constexpr (C == T::category) {
+      return Operator::Resize;
+    } else {
+      return Operator::Convert;
     }
-  } // doSet
+  }
+  /// Classify a called procedure. Non-intrinsic procedures are always Call:
+  /// name matching is applied only to intrinsics, so a user procedure that
+  /// happens to be named e.g. "max" is not mistaken for the MAX intrinsic.
+  /// Intrinsics without a dedicated code are reported as Intrinsic.
+  Operator OperationCode(const evaluate::ProcedureDesignator &proc) const {
+    if (!proc.GetSpecificIntrinsic()) {
+      return Operator::Call;
+    }
+    // The bitwise intrinsics IAND/IOR/IEOR reuse the And/Or/Neqv codes.
+    return llvm::StringSwitch<Operator>(proc.GetName())
+        .Case("associated", Operator::Associated)
+        .Case("min", Operator::Min)
+        .Case("max", Operator::Max)
+        .Case("iand", Operator::And)
+        .Case("ior", Operator::Or)
+        .Case("ieor", Operator::Neqv)
+        .Default(Operator::Intrinsic);
+  }
+  /// Fallback for every operation kind without a dedicated overload.
+  template <typename T> //
+  Operator OperationCode(const T &) const {
+    return Operator::Unk;
+  }
 
-  // 2.8.1 Simd Construct Restriction
-  if (llvm::omp::allSimdSet.test(GetContext().directive)) {
-    if (auto *clause{FindClause(llvm::omp::Clause::OMPC_simdlen)}) {
-      if (auto *clause2{FindClause(llvm::omp::Clause::OMPC_safelen)}) {
-        const auto &simdlenClause{
-            std::get<parser::OmpClause::Simdlen>(clause->u)};
-        const auto &safelenClause{
-            std::get<parser::OmpClause::Safelen>(clause2->u)};
-        // simdlen and safelen both have parameters
-        if (const auto simdlenValue{GetIntValue(simdlenClause.v)}) {
-          if (const auto safelenValue{GetIntValue(safelenClause.v)}) {
-            if (*safelenValue > 0 && *simdlenValue > *safelenValue) {
-              context_.Say(clause->source,
-                  "The parameter of the SIMDLEN clause must be less than or "
-                  "equal to the parameter of the SAFELEN clause"_err_en_US);
-            }
-          }
-        }
+  /// Copy every operand of `x` into an argument list, in operand order.
+  template <typename D, typename R, typename... Os, size_t... Is>
+  Arguments OperationArgs(const evaluate::Operation<D, R, Os...> &x,
+      std::index_sequence<Is...>) const {
+    Arguments operands;
+    (operands.push_back(SomeExpr(x.template operand<Is>())), ...);
+    return operands;
+  }
+};
+
+/// Collect the top-level designators of an expression, in traversal order.
+/// Traversal does not descend into a designator once one is found, so
+/// designators nested inside it (e.g. in subscripts) are not reported.
+struct DesignatorCollector : public evaluate::Traverse<DesignatorCollector,
+                                 std::vector<SomeExpr>, false> {
+  using Result = std::vector<SomeExpr>;
+  using Base = evaluate::Traverse<DesignatorCollector, Result, false>;
+  DesignatorCollector() : Base(*this) {}
+
+  Result Default() const { return Result{}; }
+
+  using Base::operator();
+
+  template <typename T> //
+  Result operator()(const evaluate::Designator<T> &x) const {
+    // Stop at this designator and report it as a whole.
+    evaluate::Designator<T> dsg{x};
+    Result found;
+    found.push_back(AsGenericExpr(std::move(dsg)));
+    return found;
+  }
+
+  template <typename... Rs> //
+  Result Combine(Result &&result, Rs &&...results) const {
+    // Concatenate the designators found in each part, preserving order.
+    Result merged(std::move(result));
+    auto appendTo{[&merged](auto &&more) {
+      MoveAppend(merged, std::move(more));
+    }};
+    (appendTo(std::move(results)), ...);
+    return merged;
+  }
+};
+
+/// Traversal that looks through parentheses and conversions. The result is
+/// a pair:
+/// - first: the first expression reached that is neither parenthesized nor
+///   a conversion (a designator, function reference, constant, or any other
+///   operation), or nullopt if the traversal found none;
+/// - second: the dynamic result types of the conversions passed through,
+///   outermost first.
+struct ConvertCollector
+    : public evaluate::Traverse<ConvertCollector,
+          std::pair<MaybeExpr, std::vector<evaluate::DynamicType>>, false> {
+  using Result = std::pair<MaybeExpr, std::vector<evaluate::DynamicType>>;
+  using Base = evaluate::Traverse<ConvertCollector, Result, false>;
+  ConvertCollector() : Base(*this) {}
+
+  Result Default() const { return {}; }
+
+  using Base::operator();
+
+  template <typename T> //
+  Result asSomeExpr(const T &x) const {
+    auto copy{x};
+    return {AsGenericExpr(std::move(copy)), {}};
+  }
+
+  template <typename T> //
+  Result operator()(const evaluate::Designator<T> &x) const {
+    return asSomeExpr(x);
+  }
+
+  template <typename T> //
+  Result operator()(const evaluate::FunctionRef<T> &x) const {
+    return asSomeExpr(x);
+  }
+
+  template <typename T> //
+  Result operator()(const evaluate::Constant<T> &x) const {
+    return asSomeExpr(x);
+  }
+
+  template <typename D, typename R, typename... Os>
+  Result operator()(const evaluate::Operation<D, R, Os...> &x) const {
+    if constexpr (std::is_same_v<D, evaluate::Parentheses<R>>) {
+      // Ignore parentheses.
+      return (*this)(x.template operand<0>());
+    } else if constexpr (is_convert_v<D>) {
+      if constexpr (is_complex_component_v<D>) {
+        // Only taking the real part of a complex value is a conversion
+        // (complex to real). Selecting the imaginary part produces a
+        // different value, so it must not be looked through.
+        // NOTE(review): AIMAG(z) is expected to be represented this way.
+        if (x.derived().isImaginaryPart) {
+          return asSomeExpr(x.derived());
+        }
+      }
+      // Convert should always have a typed result, so it should be safe to
+      // dereference x.GetType().
+      return Combine(
+          {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+    } else if constexpr (is_complex_constructor_v<D>) {
+      // This is a conversion iff the imaginary operand is 0.
+      if (IsZero(x.template operand<1>())) {
+        return Combine(
+            {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+      } else {
+        return asSomeExpr(x.derived());
       }
+    } else {
+      return asSomeExpr(x.derived());
     }
+  }

-    // 2.11.5 Simd construct restriction (OpenMP 5.1)
-    if (auto *sl_clause{FindClause(llvm::omp::Clause::OMPC_safelen)}) {
-      if (auto *o_clause{FindClause(llvm::omp::Clause::OMPC_order)}) {
-        const auto &orderClause{
-            std::get<parser::OmpClause::Order>(o_clause->u)};
-        if (std::get<parser::OmpOrderClause::Ordering>(orderClause.v.t) ==
-            parser::OmpOrderClause::Ordering::Concurrent) {
-          context_.Say(sl_clause->source,
-              "The `SAFELEN` clause cannot appear in the `SIMD` directive "
-              "with `ORDER(CONCURRENT)` clause"_err_en_US);
-        }
+  // Merge results: at most one part may carry an expression (asserted), and
+  // conversion types are concatenated in order.
+  template <typename... Rs> //
+  Result Combine(Result &&result, Rs &&...results) const {
+    Result v(std::move(result));
+    auto setValue{[](MaybeExpr &x, MaybeExpr &&y) {
+      assert((!x.has_value() || !y.has_value()) && "Multiple designators");
+      if (!x.has_value()) {
+        x = std::move(y);
       }
+    }};
+    (setValue(v.first, std::move(results).first), ...);
+    (MoveAppend(v.second, std::move(results).second), ...);
+    return v;
+  }
+
+private:
+  // True only for a scalar constant equal to zero; anything else, including
+  // non-constant expressions, is conservatively treated as non-zero.
+  template <typename T> //
+  static bool IsZero(const T &x) {
+    return false;
+  }
+  template <typename T> //
+  static bool IsZero(const evaluate::Expr<T> &x) {
+    return common::visit([](auto &&s) { return IsZero(s); }, x.u);
+  }
+  template <typename T> //
+  static bool IsZero(const evaluate::Constant<T> &x) {
+    if (auto &&maybeScalar{x.GetScalarValue()}) {
+      return maybeScalar->IsZero();
+    } else {
+      return false;
     }
-  } // SIMD
+  }

-  // Semantic checks related to presence of multiple list items within the same
-  // clause
-  CheckMultListItems();
+  template <typename T> //
+  struct is_convert {
+    static constexpr bool value{false};
+  };
+  template <typename T, common::TypeCategory C> //
+  struct is_convert<evaluate::Convert<T, C>> {
+    static constexpr bool value{true};
+  };
+  template <int K> //
+  struct is_convert<evaluate::ComplexComponent<K>> {
+    // Conversion from complex to real (real part only; selecting the
+    // imaginary part is excluded in operator() above).
+    static constexpr bool value{true};
+  };
+  template <typename T> //
+  static constexpr bool is_convert_v = is_convert<T>::value;
+
+  template <typename T> //
+  struct is_complex_component {
+    static constexpr bool value{false};
+  };
+  template <int K> //
+  struct is_complex_component<evaluate::ComplexComponent<K>> {
+    static constexpr bool value{true};
+  };
+  template <typename T> //
+  static constexpr bool is_complex_component_v =
+      is_complex_component<T>::value;

-  if (GetContext().directive == llvm::omp::Directive::OMPD_task) {
-    if (auto *detachClause{FindClause(llvm::omp::Clause::OMPC_detach)}) {
-      unsigned version{context_.langOptions().OpenMPVersion};
-      if (version == 50 || version == 51) {
+  template <typename T> //
+  struct is_complex_constructor {
+    static constexpr bool value{false};
+  };
+  template <int K> //
+  struct is_complex_constructor<evaluate::ComplexConstructor<K>> {
+    static constexpr bool value{true};
+  };
+  template <typename T> //
+  static constexpr bool is_complex_constructor_v =
+      is_complex_constructor<T>::value;
+};
+
+/// AnyTraverse visitor answering whether `var` occurs in a traversed
+/// expression as a designator or a function reference. Occurrences are
+/// compared as whole generic expressions with operator==.
+struct VariableFinder : public evaluate::AnyTraverse<VariableFinder> {
+  using Base = evaluate::AnyTraverse<VariableFinder>;
+  VariableFinder(const SomeExpr &v) : Base(*this), var(v) {}
+
+  using Base::operator();
+
+  template <typename T>
+  bool operator()(const evaluate::Designator<T> &x) const {
+    return Matches(x);
+  }
+
+  template <typename T>
+  bool operator()(const evaluate::FunctionRef<T> &x) const {
+    return Matches(x);
+  }
+
+private:
+  // Wrap a copy of `x` as a generic expression and compare it with `var`.
+  template <typename T> bool Matches(const T &x) const {
+    T copy{x};
+    return evaluate::AsGenericExpr(std::move(copy)) == var;
+  }
+
+  // Not owned: the referenced expression must outlive this finder.
+  const SomeExpr &var;
+};
+} // namespace atomic
+
+/// Return true if the atomic variable `expr` is ALLOCATABLE: it must consist
+/// of a single top-level designator whose last symbol is allocatable.
+static bool IsAllocatable(const SomeExpr &expr) {
+  std::vector<SomeExpr> dsgs{atomic::DesignatorCollector{}(expr)};
+  // An atomic variable may also be a function reference, which need not
+  // produce exactly one top-level designator. Such an expression is not an
+  // allocatable variable; returning early also avoids calling front() on an
+  // empty vector, which an assert alone does not prevent in release builds.
+  // NOTE(review): a function reference with a single designator argument
+  // still reaches the symbol check below -- confirm this is intended.
+  if (dsgs.size() != 1) {
+    return false;
+  }
+  evaluate::SymbolVector syms{evaluate::GetSymbolVector(dsgs.front())};
+  return !syms.empty() && IsAllocatable(syms.back());
+}
+
+/// Classify the top-level operation of `expr` and return it along with its
+/// operands. Parentheses and same-category conversions are looked through.
+static std::pair<atomic::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
+    const SomeExpr &expr) {
+  constexpr bool ignoreResizes{true};
+  atomic::ArgumentExtractor<ignoreResizes> extractor;
+  return extractor(expr);
+}
+
+/// Return the operands of the top-level operation of `expr`.
+std::vector<SomeExpr> GetOpenMPTopLevelArguments(const SomeExpr &expr) {
+  return std::get<1>(GetTopLevelOperation(expr));
+}
+
+/// Treat the assignment as a pointer assignment when its evaluated form
+/// holds either a bounds specification or a bounds remapping.
+static bool IsPointerAssignment(const evaluate::Assignment &x) {
+  const auto &form{x.u};
+  if (std::holds_alternative<evaluate::Assignment::BoundsSpec>(form)) {
+    return true;
+  }
+  return std::holds_alternative<evaluate::Assignment::BoundsRemapping>(form);
+}
+
+/// True when the top-level operation of `cond` is classified as Associated.
+static bool IsCheckForAssociated(const SomeExpr &cond) {
+  atomic::Operator topOp{GetTopLevelOperation(cond).first};
+  return topOp == atomic::Operator::Associated;
+}
+
+static bool HasCommonDesignatorSymbols(
+    const evaluate::SymbolVector &baseSyms, const SomeExpr &other) {
+  // Compare the designators used in "other" with the designators whose
+  // symbols are given in baseSyms.
+  // This is a part of the check if these two expressions can access the same
+  // storage: if the designators used in them are different enough, then they
+  // will be assumed not to access the same memory.
+  //
+  // Consider an (array element) expression x%y(w%z), the corresponding symbol
+  // vector will be {x, y, w, z} (i.e. the symbols for these names).
+  // Check whether this exact sequence appears contiguously anywhere in the
+  // symbol vector for "other". This will be true for x(y) and x(y+1), so this
+  // is not a sufficient condition, but can be used to eliminate candidates
+  // before doing more exhaustive checks.
+  //
+  // If any of the symbols in this sequence are function names, assume that
+  // there is no storage overlap, mostly because it would be impossible in
+  // general to determine what storage the function will access.
+  // Note: if f is pure, then two calls to f will access the same storage
+  // when called with the same arguments. This check is not done yet.
+
+  if (llvm::any_of(
+          baseSyms, [](const SymbolRef &s) { return s->IsSubprogram(); })) {
+    // If there is a function symbol in the chain then we can't infer much
+    // about the accessed storage.
+    return false;
+  }
+
+  // An empty sequence occurs in every vector. std::search reports "not
+  // found" when both ranges are empty, so handle this case up front.
+  if (baseSyms.empty()) {
+    return true;
+  }
+
+  // Look for baseSyms as a contiguous run inside the symbols of "other".
+  evaluate::SymbolVector otherSyms{evaluate::GetSymbolVector(other)};
+  return std::search(otherSyms.begin(), otherSyms.end(), baseSyms.begin(),
+             baseSyms.end()) != otherSyms.end();
+}
+
+/// True if some top-level designator of `other` equals one of `baseDsgs`.
+static bool HasCommonTopLevelDesignators(
+    const std::vector<SomeExpr> &baseDsgs, const SomeExpr &other) {
+  // Designators are compared as whole expressions rather than symbol lists,
+  // so x(y) and x(y+1) are distinct here even though their symbol vectors
+  // are identical.
+  std::vector<SomeExpr> otherDsgs{atomic::DesignatorCollector{}(other)};
+  auto occursInOther{[&](const SomeExpr &dsg) {
+    return llvm::any_of(
+        otherDsgs, [&](const SomeExpr &cand) { return dsg == cand; });
+  }};
+  return llvm::any_of(baseDsgs, occursInOther);
+}
+
+/// Return a pointer to the first expression in `exprs` that may access the
+/// same storage as `base`, or nullptr if there is none.
+static const SomeExpr *HasStorageOverlap(
+    const SomeExpr &base, llvm::ArrayRef<SomeExpr> exprs) {
+  evaluate::SymbolVector baseSyms{evaluate::GetSymbolVector(base)};
+  std::vector<SomeExpr> baseDsgs{atomic::DesignatorCollector{}(base)};
+
+  // Cheap symbol-sequence filter first, then exact designator comparison.
+  for (const SomeExpr &candidate : exprs) {
+    if (HasCommonDesignatorSymbols(baseSyms, candidate) &&
+        HasCommonTopLevelDesignators(baseDsgs, candidate)) {
+      return &candidate;
+    }
+  }
+  return nullptr;
+}
+
+/// An assignment may be an atomic write if its right-hand side does not
+/// appear to access the storage of its left-hand side. Function calls are
+/// ignored by the overlap check, so e.g. "f(x) = f(x) + 1" is accepted.
+static bool IsMaybeAtomicWrite(const evaluate::Assignment &assign) {
+  const SomeExpr *overlap{HasStorageOverlap(assign.lhs, assign.rhs)};
+  return !overlap;
+}
+
+/// Strip parentheses and conversions from `x` and return the expression
+/// underneath. For a designator, function reference or constant this is
+/// `x` itself. The conversion types seen along the way are discarded.
+MaybeExpr GetConvertInput(const SomeExpr &x) {
+  atomic::ConvertCollector collector;
+  return collector(x).first;
+}
+
+/// True if `expr` is `x` itself, or `x` wrapped only in the conversions
+/// (and parentheses) that GetConvertInput strips.
+bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) {
+  if (expr == x) {
+    return true;
+  }
+  MaybeExpr input{GetConvertInput(expr)};
+  return input.has_value() && *input == x;
+}
+
+/// True if `sub` occurs inside `super` as a designator or a function
+/// reference. For any other kind of `sub` no match is possible.
+bool IsSubexpressionOf(const SomeExpr &sub, const SomeExpr &super) {
+  atomic::VariableFinder finder{sub};
+  return finder(super);
+}
+
+/// Store `value` as the typed expression in `expr`. An empty `value`
+/// leaves `expr` unchanged.
+static void SetExpr(parser::TypedExpr &expr, MaybeExpr value) {
+  if (!value) {
+    return;
+  }
+  expr.Reset(new evaluate::GenericExprWrapper(std::move(value)),
+      evaluate::GenericExprWrapper::Deleter);
+}
+
+/// Store `value` as the typed assignment in `assign`. An empty `value`
+/// leaves `assign` unchanged.
+static void SetAssignment(parser::AssignmentStmt::TypedAssignment &assign,
+    std::optional<evaluate::Assignment> value) {
+  if (!value) {
+    return;
+  }
+  assign.Reset(new evaluate::GenericAssignmentWrapper(std::move(value)),
+      evaluate::GenericAssignmentWrapper::Deleter);
+}
+
+/// Build one operation record of an atomic analysis: the operation kind
+/// `what` and, when provided, the assignment that performs it.
+static parser::OpenMPAtomicConstruct::Analysis::Op MakeAtomicAnalysisOp(
+    int what,
+    const std::optional<evaluate::Assignment> &maybeAssign = std::nullopt) {
+  using AnalysisOp = parser::OpenMPAtomicConstruct::Analysis::Op;
+  AnalysisOp op;
+  op.what = what;
+  SetAssignment(op.assign, maybeAssign);
+  return op;
+}
+
+/// Assemble the analysis record of an ATOMIC construct from the atomic
+/// variable, the optional condition, and the two operation records.
+static parser::OpenMPAtomicConstruct::Analysis MakeAtomicAnalysis(
+    const SomeExpr &atom, const MaybeExpr &cond,
+    parser::OpenMPAtomicConstruct::Analysis::Op &&op0,
+    parser::OpenMPAtomicConstruct::Analysis::Op &&op1) {
+  // The record type is defined in flang/include/flang/Parser/parse-tree.h:
+  //
+  // struct Analysis {
+  //   struct Kind {
+  //     static constexpr int None = 0;
+  //     static constexpr int Read = 1;
+  //     static constexpr int Write = 2;
+  //     static constexpr int Update = Read | Write;
+  //     static constexpr int Action = 3; // Bits containing N, R, W, U
+  //     static constexpr int IfTrue = 4;
+  //     static constexpr int IfFalse = 8;
+  //     static constexpr int Condition = 12; // Bits containing IfTrue, IfFalse
+  //   };
+  //   struct Op {
+  //     int what;
+  //     TypedAssignment assign;
+  //   };
+  //   TypedExpr atom, cond;
+  //   Op op0, op1;
+  // };
+
+  parser::OpenMPAtomicConstruct::Analysis analysis;
+  analysis.op0 = std::move(op0);
+  analysis.op1 = std::move(op1);
+  SetExpr(analysis.atom, atom);
+  SetExpr(analysis.cond, cond);
+  return analysis;
+}
+
+void OmpStructureChecker::CheckStorageOverlap(const SomeExpr &base,
+    llvm::ArrayRef<evaluate::Expr<evaluate::SomeType>> exprs,
+    parser::CharBlock source) {
+  if (auto *expr{HasStorageOverlap(base, exprs)}) {
+    context_.Say(source,
+        "Within atomic operation %s and %s access the same storage"_warn_en_US,
+        base.AsFortran(), expr->AsFortran());
+  }
+}
+
+/// Report that an atomic expression is not a variable, quoting the
+/// expression in the message when it is available.
+void OmpStructureChecker::ErrorShouldBeVariable(
+    const MaybeExpr &expr, parser::CharBlock source) {
+  if (!expr) {
+    context_.Say(source, "Atomic expression should be a variable"_err_en_US);
+    return;
+  }
+  context_.Say(source, "Atomic expression %s should be a variable"_err_en_US,
+      expr->AsFortran());
+}
+
+/// Check that `atom` satisfies the following conditions for x and v:
+///
+/// [6.0:189:10-12]
+/// - x and v (as applicable) are either scalar variables or
+///   function references with scalar data pointer result of non-character
+///   intrinsic type or variables that are non-polymorphic scalar pointers
+///   and any length type parameter must be constant.
+///
+/// Each violated condition (non-scalar, CHARACTER, polymorphic, ALLOCATABLE)
+/// is reported as its own error at `source`.
+void OmpStructureChecker::CheckAtomicVariable(
+    const SomeExpr &atom, parser::CharBlock source) {
+  bool isScalar{atom.Rank() == 0};
+  if (!isScalar) {
+    context_.Say(source, "Atomic variable %s should be a scalar"_err_en_US,
+        atom.AsFortran());
+  }
+
+  if (auto dtype{atom.GetType()}) {
+    bool isCharacter{dtype->category() == TypeCategory::Character};
+    if (isCharacter) {
+      context_.Say(source,
+          "Atomic variable %s cannot have CHARACTER type"_err_en_US,
+          atom.AsFortran());
+    } else if (dtype->IsPolymorphic()) {
+      context_.Say(source,
+          "Atomic variable %s cannot have a polymorphic type"_err_en_US,
+          atom.AsFortran());
+    }
+    // TODO: Check non-constant type parameters for non-character types.
+    // At the moment there don't seem to be any.
+  }
+
+  if (IsAllocatable(atom)) {
+    context_.Say(source, "Atomic variable %s cannot be ALLOCATABLE"_err_en_US,
+        atom.AsFortran());
+  }
+}
+
+std::pair<const parser::ExecutionPartConstruct *,
+    const parser::ExecutionPartConstruct *>
+OmpStructureChecker::CheckUpdateCapture(
+    const parser::ExecutionPartConstruct *ec1,
+    const parser::ExecutionPartConstruct *ec2, parser::CharBlock source) {
+  // Decide which statement is the atomic update and which is the capture.
+  //
+  // The two allowed cases are:
+  //   x = ...      atomic-var = ...
+  //   ... = x      capture-var = atomic-var (with optional converts)
+  // or
+  //   ... = x      capture-var = atomic-var (with optional converts)
+  //   x = ...      atomic-var = ...
+  //
+  // The case of 'a = b; b = a' is ambiguous, so pick the first one as capture
+  // (which makes more sense, as it captures the original value of the atomic
+  // variable).
+  //
+  // If the two statements don't fit these criteria, return a pair of default-
+  // constructed values.
+  using ReturnTy = std::pair<const parser::ExecutionPartConstruct *,
+      const parser::ExecutionPartConstruct *>;
+
+  SourcedActionStmt act1{GetActionStmt(ec1)};
+  SourcedActionStmt act2{GetActionStmt(ec2)};
+  auto maybeAssign1{GetEvaluateAssignment(act1.stmt)};
+  auto maybeAssign2{GetEvaluateAssignment(act2.stmt)};
+  if (!maybeAssign1 || !maybeAssign2) {
+    context_.Say(source,
+        "ATOMIC UPDATE operation with CAPTURE should contain two assignments"_err_en_US);
+    return std::make_pair(nullptr, nullptr);
+  }
+
+  auto as1{*maybeAssign1}, as2{*maybeAssign2};
+
+  auto isUpdateCapture{
+      [](const evaluate::Assignment &u, const evaluate::Assignment &c) {
+        return IsSameOrConvertOf(c.rhs, u.lhs);
+      }};
+
+  // Do some checks that narrow down the possible choices for the update
+  // and the capture statements. This will help to emit better diagnostics.
+  // 1. An assignment could be an update (cbu) if the left-hand side is a
+  //    subexpression of the right-hand side.
+  // 2. An assignment could be a capture (cbc) if the right-hand side is
+  //    a variable (or a function ref), with potential type conversions.
+  bool cbu1{IsSubexpressionOf(as1.lhs, as1.rhs)};
+  bool cbu2{IsSubexpressionOf(as2.lhs, as2.rhs)};
+  bool cbc1{IsVarOrFunctionRef(GetConvertInput(as1.rhs))};
+  bool cbc2{IsVarOrFunctionRef(GetConvertInput(as2.rhs))};
+
+  //     |cbu1 cbu2|
+  // det |cbc1 cbc2| = cbu1*cbc2 - cbu2*cbc1
+  int det{int(cbu1) * int(cbc2) - int(cbu2) * int(cbc1)};
----------------
kparzysz wrote:

I added an explanation for the use of the determinant instead. This technique is used twice, and I think the explanation makes it clear. Let me know what you think.

https://github.com/llvm/llvm-project/pull/137852


More information about the llvm-commits mailing list