[flang-commits] [flang] [flang][OpenMP] Improve reduction of Scalar ArrayElement types (PR #163940)

Jack Styles via flang-commits flang-commits at lists.llvm.org
Mon Oct 20 03:57:59 PDT 2025


https://github.com/Stylie777 updated https://github.com/llvm/llvm-project/pull/163940

From 6233d8df01cc8cc580debfe4a991cf0c6240ff61 Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Thu, 2 Oct 2025 16:52:43 +0100
Subject: [PATCH 1/2] [flang][OpenMP] Fix reduction of Scalar ArrayElement
 types

Currently, Flang does not correctly lower ArrayElements used in an
OpenMP Reduction Clause. Rather than lowering the array element,
the whole array is lowered, which leads to slower performance in
the user's program.

This patch rectifies this by rewriting the parse tree during
semantic analysis. Each ArrayElement used in an OpenMP Reduction
Clause is identified and replaced with a temporary, both in the
reduction clause and anywhere that array element is used within
the associated DoConstruct. The temporary is assigned the value of
the array element immediately before the construct. Once the
DoConstruct has finished, if the ArrayElement has been used within
the DO loop, the value of the temporary is assigned back to the
array element. One limitation of this approach is that if the
ArrayElement is not used in the loop, there is no existing node in
the parse tree from which to build the reassignment, so the value
is only reassigned when the element is used.
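The rewrite can be illustrated with a toy, string-based sketch. It
is not the patch's implementation: `ReductionLoop`, `RewrittenLoop`,
`rewriteToTemp` and `replaceAll` are hypothetical names, statements
are plain strings rather than parse-tree nodes, and textual matching
stands in for the symbol and subscript comparison done on the parse
tree.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy model: statements are plain strings, whereas the patch
// rewrites parser::Block / parser::Designator nodes in the parse tree.
struct ReductionLoop {
  std::string reductionItem;      // array element in the clause, e.g. "a(10)"
  std::vector<std::string> body;  // statements inside the DO loop
};

struct RewrittenLoop {
  std::string reductionItem;        // clause item after rewriting
  std::vector<std::string> before;  // statements inserted before the construct
  std::vector<std::string> body;    // rewritten DO loop body
  std::vector<std::string> after;   // statements inserted after the construct
};

// Replace every occurrence of `from` in `s` with `to` and report whether
// anything changed. Textual matching is only a stand-in for comparing
// symbols and subscripts on the parse tree.
static bool replaceAll(std::string &s, const std::string &from,
                       const std::string &to) {
  bool changed = false;
  std::string::size_type pos = s.find(from);
  while (pos != std::string::npos) {
    s.replace(pos, from.size(), to);
    changed = true;
    pos = s.find(from, pos + to.size());
  }
  return changed;
}

RewrittenLoop rewriteToTemp(const ReductionLoop &loop,
                            const std::string &temp) {
  assert(!temp.empty() && "temporary needs a name");
  RewrittenLoop out;
  out.reductionItem = temp;  // the reduction clause now names the temporary
  out.before.push_back(temp + " = " + loop.reductionItem);  // copy in
  bool used = false;
  for (std::string stmt : loop.body) {
    if (replaceAll(stmt, loop.reductionItem, temp)) used = true;
    out.body.push_back(stmt);
  }
  // Copy back only when the element appeared in the loop body, mirroring
  // the limitation described above.
  if (used) out.after.push_back(loop.reductionItem + " = " + temp);
  return out;
}
```

For a loop body `a(10) = a(10)+1` and a temporary named `tmp`, this
yields `tmp = a(10)` before the loop, `tmp = tmp+1` in the body and
`a(10) = tmp` afterwards; a body such as `k = k+1` gets no copy back.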

The reason for making the change in the parse tree is how
ArrayElements are lowered. When lowering, the expression of the
ArrayElement used in the reduction is substituted with a reference
to its symbol, which in this case is the whole array. Replacing the
element with a temporary avoids lowering the full array, as the
reduction references the temporary instead. Addressing this in
lowering would require a major rethink of how a considerable amount
of non-OpenMP code is lowered, so it was not deemed the appropriate
course of action for this specific case.

This rewrite is done after the semantic checks (including the
OpenMP structure checks) so that it does not affect the checking of
the user's original code, and it is skipped if any fatal error has
been reported. If an array element has been replaced, the first
statement semantics pass (StatementSemanticsPass1) is rerun so that
all TypedExprs are recaptured; otherwise lowering would not function
correctly. This rerun only happens if an ArrayElement was replaced.
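The ordering can be sketched as a small driver. The `Passes` struct,
`runSemanticsSketch` and `traceOrder` are hypothetical stand-ins; in
the patch this logic sits inline in PerformStatementSemantics in
flang/lib/Semantics/semantics.cpp, which also runs other passes
(OpenACC, CUDA) that are omitted here.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the passes named in the commit message.
struct Passes {
  std::function<void()> pass1;          // StatementSemanticsPass1
  std::function<void()> pass2;          // StatementSemanticsPass2
  std::function<void()> ompChecks;      // OmpStructureChecker walk
  std::function<bool()> anyFatalError;  // context.AnyFatalError()
  std::function<bool()> rewriteReductionArrayElements;  // true if rewritten
};

void runSemanticsSketch(const Passes &p) {
  assert(p.pass1 && p.pass2 && p.ompChecks && p.anyFatalError &&
         p.rewriteReductionArrayElements);
  p.pass1();
  p.pass2();
  p.ompChecks();  // the user's original code is checked first
  // Rewrite only error-free programs; rerun pass1 only if something changed.
  if (!p.anyFatalError() && p.rewriteReductionArrayElements()) {
    p.pass1();  // recapture the TypedExprs reset by the rewrite
  }
}

// Usage helper: record which passes run for a given outcome.
std::vector<std::string> traceOrder(bool fatal, bool rewrote) {
  std::vector<std::string> log;
  Passes p{[&] { log.push_back("pass1"); },
           [&] { log.push_back("pass2"); },
           [&] { log.push_back("omp-checks"); },
           [&] { return fatal; },
           [&] {
             log.push_back("rewrite");
             return rewrote;
           }};
  runSemanticsSketch(p);
  return log;
}
```

`traceOrder(false, true)` records pass1, pass2, omp-checks, rewrite
and pass1 again; when nothing is rewritten the final pass1 is
skipped, and with a fatal error the rewrite is never attempted.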

Testing is covered by reduction17.f90, which checks the parse tree,
the unparsed output and the HLFIR to ensure the temporary is used in
the reduction clause and the DO loop. The assignment from the
ArrayElement to the temporary, and the reassignment from the
temporary back to the ArrayElement, are also checked to ensure they
are inserted at the correct locations.

reduction09.f90 has also been converted to use FileCheck. As the
parse tree is changed, the unparsed output no longer matches the
user's source, so test_symbols.py can no longer be used for this
test. The same information is checked, with test cases covering
both an ArrayElement that is used in the DO loop and one that is
not.

Array Sections are not affected by this change, only uses of
single array elements.
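To make the element/section distinction concrete, here is a
hypothetical string-based classifier (`subscriptsOf` and
`isSingleArrayElement` are illustrative names, not part of the
patch). It treats any subscript containing `:` as a triplet and
ignores vector subscripts, so it is a simplification of the Fortran
rules rather than a copy of the patch's setOriginalSubscriptInt
logic.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Split the text between the outer parentheses of a designator such as
// "b(1:10:1,1:5,2)" into its top-level subscripts.
std::vector<std::string> subscriptsOf(const std::string &designator) {
  std::string::size_type open = designator.find('(');
  assert(open != std::string::npos && designator.back() == ')');
  std::vector<std::string> parts;
  std::string current;
  int depth = 0;  // nesting depth, so commas inside f(1,2) are not split
  for (std::string::size_type i = open + 1; i + 1 < designator.size(); ++i) {
    char c = designator[i];
    if (c == '(') ++depth;
    if (c == ')') --depth;
    if (c == ',' && depth == 0) {
      parts.push_back(current);
      current.clear();
    } else {
      current += c;
    }
  }
  parts.push_back(current);
  return parts;
}

// Treat a designator as a single array element only if none of its
// subscripts is a triplet; otherwise it is an array section and is left alone.
bool isSingleArrayElement(const std::string &designator) {
  for (const std::string &s : subscriptsOf(designator))
    if (s.find(':') != std::string::npos) return false;
  return true;
}
```

With this sketch, `a(10)` is classified as a single element, while
`a(1:10:1)` and `b(1:10:1,1:5,2)` are sections.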
---
 flang/lib/Semantics/rewrite-parse-tree.cpp  | 524 ++++++++++++++++++++
 flang/lib/Semantics/rewrite-parse-tree.h    |   2 +
 flang/lib/Semantics/semantics.cpp           |  13 +-
 flang/test/Semantics/OpenMP/reduction09.f90 | 109 +++-
 flang/test/Semantics/OpenMP/reduction17.f90 | 209 ++++++++
 5 files changed, 837 insertions(+), 20 deletions(-)
 create mode 100644 flang/test/Semantics/OpenMP/reduction17.f90

diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 5b7dab309eda7..5379dcdd3d40c 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -95,6 +95,86 @@ class RewriteMutator {
   parser::Messages &messages_;
 };
 
+class ReplacementTemp {
+public:
+  ReplacementTemp() {}
+
+  void createTempSymbol(
+      SourceName &source, Scope &scope, SemanticsContext &context);
+  void setOriginalSubscriptInt(
+      std::list<parser::SectionSubscript> &sectionSubscript);
+  Symbol *getTempSymbol() { return replacementTempSymbol_; }
+  Symbol *getPrivateReductionSymbol() { return privateReductionSymbol_; }
+  parser::CharBlock getOriginalSource() { return originalSource_; }
+  parser::Name getOriginalName() { return originalName_; }
+  parser::CharBlock getOriginalSubscript() {
+    return originalSubscriptCharBlock_;
+  }
+  Scope *getTempScope() { return tempScope_; }
+  bool isArrayElementReassigned() { return arrayElementReassigned_; }
+  bool isSectionTriplet() { return isSectionTriplet_; }
+  void arrayElementReassigned() { arrayElementReassigned_ = true; }
+  void setOriginalName(parser::Name &name) {
+    originalName_ = common::Clone(name);
+  }
+  void setOriginalSource(parser::CharBlock &source) {
+    originalSource_ = source;
+  }
+  void setOriginalSubscriptInt(parser::CharBlock &subscript) {
+    originalSubscriptCharBlock_ = subscript;
+  }
+  void setTempScope(Scope &scope) { tempScope_ = &scope; };
+  void setTempSymbol(Symbol *symbol) { replacementTempSymbol_ = symbol; }
+
+private:
+  Symbol *replacementTempSymbol_{nullptr};
+  Symbol *privateReductionSymbol_{nullptr};
+  Scope *tempScope_{nullptr};
+  parser::CharBlock originalSource_;
+  parser::Name originalName_;
+  parser::CharBlock originalSubscriptCharBlock_;
+  bool arrayElementReassigned_{false};
+  bool isSectionTriplet_{false};
+};
+
+class RewriteOmpReductionArrayElements {
+public:
+  explicit RewriteOmpReductionArrayElements(SemanticsContext &context)
+      : context_(context) {}
+  // Default action for a parse tree node is to visit children.
+  template <typename T> bool Pre(T &) { return true; }
+  template <typename T> void Post(T &) {}
+
+  void Post(parser::Block &block);
+  void Post(parser::Variable &var);
+  void Post(parser::Expr &expr);
+  void Post(parser::AssignmentStmt &assignmentStmt);
+  void Post(parser::PointerAssignmentStmt &ptrAssignmentStmt);
+  void rewriteReductionArrayElementToTemp(parser::Block &block);
+  bool isArrayElementRewritten() { return arrayElementReassigned_; }
+
+private:
+  bool isMatchingArrayElement(parser::Designator &existingDesignator);
+  template <typename T>
+  void processFunctionReference(
+      T &node, parser::CharBlock source, parser::FunctionReference &funcRef);
+  parser::Designator makeTempDesignator(parser::CharBlock source);
+  bool rewriteArrayElementToTemp(parser::Block::iterator &it,
+      parser::OpenMPLoopConstruct &ompLoop, parser::Block &block,
+      ReplacementTemp &temp);
+  bool identifyArrayElementReduced(
+      parser::Designator &designator, ReplacementTemp &temp);
+  void reassignTempValueToArrayElement(parser::ArrayElement &arrayElement);
+  void setCurrentTemp(ReplacementTemp *temp) { currentTemp_ = temp; }
+  void resetCurrentTemp() { currentTemp_ = nullptr; }
+
+  SemanticsContext &context_;
+  bool arrayElementReassigned_{false};
+  parser::Block::iterator reassignmentInsertionPoint_;
+  parser::Block *block_{nullptr};
+  ReplacementTemp *currentTemp_{nullptr};
+};
+
 // Check that name has been resolved to a symbol
 void RewriteMutator::Post(parser::Name &name) {
   if (!name.symbol && errorOnUnresolvedName_) {
@@ -492,10 +572,454 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
   FixMisparsedUntaggedNamelistName(x);
 }
 
+void ReplacementTemp::createTempSymbol(
+    SourceName &source, Scope &scope, SemanticsContext &context) {
+  replacementTempSymbol_ =
+      const_cast<semantics::Scope &>(originalName_.symbol->owner())
+          .FindSymbol(source);
+  replacementTempSymbol_->set_scope(
+      &const_cast<semantics::Scope &>(originalName_.symbol->owner()));
+  DeclTypeSpec *tempType = originalName_.symbol->GetUltimate().GetType();
+  replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+  replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+void ReplacementTemp::setOriginalSubscriptInt(
+    std::list<parser::SectionSubscript> &sectionSubscript) {
+  bool setSubscript{false};
+  for (parser::SectionSubscript &subscript : sectionSubscript) {
+    std::visit(llvm::makeVisitor(
+                   [&](parser::IntExpr &intExpr) {
+                     parser::Expr &expr = intExpr.thing.value();
+                     std::visit(
+                         llvm::makeVisitor(
+                             [&](parser::LiteralConstant &literalContant) {
+                               std::visit(llvm::makeVisitor(
+                                              [&](parser::IntLiteralConstant
+                                                      &intLiteralConstant) {
+                                                originalSubscriptCharBlock_ =
+                                                    std::get<parser::CharBlock>(
+                                                        intLiteralConstant.t);
+                                                setSubscript = true;
+                                              },
+                                              [&](auto &) {}),
+                                   literalContant.u);
+                             },
+                             [&](auto &) {}),
+                         expr.u);
+                   },
+                   [&](parser::SubscriptTriplet &triplet) {
+                     isSectionTriplet_ = true;
+                     setSubscript = true;
+                   },
+                   [&](auto &) {}),
+        subscript.u);
+    if (setSubscript) {
+      break;
+    }
+  }
+}
+
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+    parser::Block &block) {
+  if (block.empty()) {
+    return;
+  }
+
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    std::visit(
+        llvm::makeVisitor(
+            [&](parser::ExecutableConstruct &execConstruct) {
+              std::visit(
+                  llvm::makeVisitor(
+                      [&](common::Indirection<parser::OpenMPConstruct>
+                              &ompConstruct) {
+                        std::visit(
+                            llvm::makeVisitor(
+                                [&](parser::OpenMPLoopConstruct &ompLoop) {
+                                  ReplacementTemp temp;
+                                  if (!rewriteArrayElementToTemp(
+                                          it, ompLoop, block, temp)) {
+                                    return;
+                                  }
+                                  auto &NestedConstruct = std::get<
+                                      std::optional<parser::NestedConstruct>>(
+                                      ompLoop.t);
+                                  if (!NestedConstruct.has_value()) {
+                                    return;
+                                  }
+                                  if (parser::DoConstruct *
+                                      doConst{std::get_if<parser::DoConstruct>(
+                                          &NestedConstruct.value())}) {
+                                    block_ = █
+                                    parser::Block &doBlock{
+                                        std::get<parser::Block>(doConst->t)};
+                                    parser::Walk(doBlock, *this);
+                                    // Reset the current temp value so future
+                                    // iterations use their own version.
+                                    resetCurrentTemp();
+                                  }
+                                },
+                                [&](auto &) {}),
+                            ompConstruct.value().u);
+                      },
+                      [&](auto &) {}),
+                  execConstruct.u);
+            },
+            [&](auto &) {}),
+        it->u);
+  }
+}
+
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+    parser::Designator &existingDesignator) {
+  bool matchesArrayElement{false};
+  std::list<parser::SectionSubscript> *subscripts{nullptr};
+
+  std::visit(llvm::makeVisitor(
+                 [&](parser::DataRef &dataRef) {
+                   std::visit(
+                       llvm::makeVisitor(
+                           [&](common::Indirection<parser::ArrayElement>
+                                   &arrayElement) {
+                             subscripts = &arrayElement.value().subscripts;
+                             std::visit(
+                                 llvm::makeVisitor(
+                                     [&](parser::Name &name) {
+                                       if (name.symbol->GetUltimate() ==
+                                           currentTemp_->getOriginalName()
+                                               .symbol->GetUltimate()) {
+                                         matchesArrayElement = true;
+                                         if (!currentTemp_
+                                                 ->isArrayElementReassigned()) {
+                                           reassignTempValueToArrayElement(
+                                               arrayElement.value());
+                                         }
+                                       }
+                                     },
+                                     [](auto &) {}),
+                                 arrayElement.value().base.u);
+                           },
+                           [&](parser::Name &name) {
+                             if (name.symbol->GetUltimate() ==
+                                 currentTemp_->getOriginalName()
+                                     .symbol->GetUltimate()) {
+                               matchesArrayElement = true;
+                             }
+                           },
+                           [](auto &) {}),
+                       dataRef.u);
+                 },
+                 [&](auto &) {}),
+      existingDesignator.u);
+
+  if (subscripts) {
+    bool foundSubscript{false};
+    for (parser::SectionSubscript &subscript : *subscripts) {
+      matchesArrayElement = std::visit(
+          llvm::makeVisitor(
+              [&](parser::IntExpr &intExpr) -> bool {
+                parser::Expr &expr = intExpr.thing.value();
+                return std::visit(
+                    llvm::makeVisitor(
+                        [&](parser::LiteralConstant &literalContant) -> bool {
+                          return std::visit(
+                              llvm::makeVisitor(
+                                  [&](parser::IntLiteralConstant
+                                          &intLiteralConstant) -> bool {
+                                    foundSubscript = true;
+                                    assert(currentTemp_ != nullptr &&
+                                        "Value for ReplacementTemp should have "
+                                        "been found");
+                                    if (std::get<parser::CharBlock>(
+                                            intLiteralConstant.t) ==
+                                        currentTemp_->getOriginalSubscript()) {
+                                      return true;
+                                    }
+                                    return false;
+                                  },
+                                  [](auto &) -> bool { return false; }),
+                              literalContant.u);
+                        },
+                        [](auto &) -> bool { return false; }),
+                    expr.u);
+              },
+              [](auto &) -> bool { return false; }),
+          subscript.u);
+      if (foundSubscript) {
+        break;
+      }
+    }
+  }
+  return matchesArrayElement;
+}
+
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+    T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  auto &[procedureDesignator, ArgSpecList] = funcRef.v.t;
+  std::optional<parser::Designator> arrayElementDesignator =
+      std::visit(llvm::makeVisitor(
+                     [&](parser::Name &functionReferenceName)
+                         -> std::optional<parser::Designator> {
+                       if (currentTemp_->getOriginalName().symbol ==
+                           functionReferenceName.symbol) {
+                         return funcRef.ConvertToArrayElementRef();
+                       }
+                       return std::nullopt;
+                     },
+                     [&](auto &) -> std::optional<parser::Designator> {
+                       return std::nullopt;
+                     }),
+          procedureDesignator.u);
+
+  if (arrayElementDesignator.has_value()) {
+    if (this->isMatchingArrayElement(arrayElementDesignator.value())) {
+      node = T{
+          common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+    }
+  }
+}
+
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+    parser::CharBlock source) {
+  parser::Name tempVariableName{currentTemp_->getTempSymbol()->name()};
+  tempVariableName.symbol = currentTemp_->getTempSymbol();
+  parser::Designator tempDesignator{
+      parser::DataRef{std::move(tempVariableName)}};
+  tempDesignator.source = source;
+  return tempDesignator;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+    parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+    parser::Block &block, ReplacementTemp &temp) {
+  parser::OmpBeginLoopDirective &ompBeginLoop{
+      std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+  std::list<parser::OmpClause> &clauseList{
+      std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
+  bool rewrittenArrayElement{false};
+
+  for (auto iter{clauseList.begin()}; iter != clauseList.end(); ++iter) {
+    rewrittenArrayElement = std::visit(
+        llvm::makeVisitor(
+            [&](parser::OmpClause::Reduction &clause) -> bool {
+              std::list<parser::OmpObject> &objectList =
+                  std::get<parser::OmpObjectList>(clause.v.t).v;
+
+              bool rewritten{false};
+              for (auto object{objectList.begin()}; object != objectList.end();
+                  ++object) {
+                rewritten = std::visit(
+                    llvm::makeVisitor(
+                        [&](parser::Designator &designator) -> bool {
+                          if (!identifyArrayElementReduced(designator, temp)) {
+                            return false;
+                          }
+                          if (temp.isSectionTriplet()) {
+                            return false;
+                          }
+
+                          reassignmentInsertionPoint_ =
+                              it != block.end() ? it : block.end();
+                          std::string tempSourceString = "reduction_temp_" +
+                              temp.getOriginalSource().ToString() + "(" +
+                              temp.getOriginalSubscript().ToString() + ")";
+                          SourceName source = context_.SaveTempName(
+                              std::move(tempSourceString));
+                          Scope &scope = const_cast<Scope &>(
+                              temp.getOriginalName().symbol->owner());
+                          if (Symbol * symbol{scope.FindSymbol(source)}) {
+                            temp.setTempSymbol(symbol);
+                          } else {
+                            if (scope
+                                    .try_emplace(source, semantics::Attrs{},
+                                        semantics::ObjectEntityDetails{})
+                                    .second) {
+                              temp.createTempSymbol(source, scope, context_);
+                            } else {
+                              common::die("Failed to create temp symbol for %s",
+                                  source.ToString().c_str());
+                            }
+                          }
+                          setCurrentTemp(&temp);
+                          temp.setTempScope(scope);
+
+                          // Assign the value of the array element to the
+                          // temporary variable
+                          parser::Variable newVariable{
+                              makeTempDesignator(temp.getOriginalSource())};
+                          parser::Expr newExpr{
+                              common::Indirection<parser::Designator>{
+                                  std::move(designator)}};
+                          newExpr.source = temp.getOriginalSource();
+                          std::tuple<parser::Variable, parser::Expr> newT{
+                              std::move(newVariable), std::move(newExpr)};
+                          parser::AssignmentStmt assignment{std::move(newT)};
+                          parser::ExecutionPartConstruct
+                              tempVariablePartConstruct{
+                                  parser::ExecutionPartConstruct{
+                                      parser::ExecutableConstruct{
+                                          parser::Statement<parser::ActionStmt>{
+                                              std::optional<parser::Label>{},
+                                              std::move(assignment)}}}};
+                          block.insert(
+                              it, std::move(tempVariablePartConstruct));
+                          arrayElementReassigned_ = true;
+
+                          designator =
+                              makeTempDesignator(temp.getOriginalSource());
+                          return true;
+                        },
+                        [&](const auto &) -> bool { return false; }),
+                    object->u);
+              }
+              return rewritten;
+            },
+            [&](auto &) -> bool { return false; }),
+        iter->u);
+
+    if (rewrittenArrayElement) {
+      return rewrittenArrayElement;
+    }
+  }
+  return rewrittenArrayElement;
+}
+
+bool RewriteOmpReductionArrayElements::identifyArrayElementReduced(
+    parser::Designator &designator, ReplacementTemp &temp) {
+  return std::visit(
+      llvm::makeVisitor(
+          [&](parser::DataRef &dataRef) -> bool {
+            return std::visit(
+                llvm::makeVisitor(
+                    [&](common::Indirection<parser::ArrayElement>
+                            &arrayElement) {
+                      std::visit(llvm::makeVisitor(
+                                     [&](parser::Name &name) -> void {
+                                       temp.setOriginalName(name);
+                                       temp.setOriginalSource(name.source);
+                                     },
+                                     [&](auto &) -> void {}),
+                          arrayElement.value().base.u);
+                      temp.setOriginalSubscriptInt(
+                          arrayElement.value().subscripts);
+                      return !temp.isSectionTriplet() ? true : false;
+                    },
+                    [&](auto &) -> bool { return false; }),
+                dataRef.u);
+          },
+          [&](auto &) -> bool { return false; }),
+      designator.u);
+}
+
+void RewriteOmpReductionArrayElements::reassignTempValueToArrayElement(
+    parser::ArrayElement &arrayElement) {
+  assert(block_ && "Need iterator to reassign");
+  parser::CharBlock originalSource = currentTemp_->getOriginalSource();
+  parser::DataRef reassignmentDataRef{std::move(arrayElement)};
+  common::Indirection<parser::Designator> arrayElementDesignator{
+      std::move(reassignmentDataRef)};
+  arrayElementDesignator.value().source = originalSource;
+  parser::Variable exisitingVar{std::move(arrayElementDesignator)};
+  std::get<common::Indirection<parser::Designator>>(exisitingVar.u)
+      .value()
+      .source = originalSource;
+  parser::Expr reassignmentExpr{makeTempDesignator(originalSource)};
+  SourceName source{"reductionTemp"};
+  reassignmentExpr.source = source;
+  std::tuple<parser::Variable, parser::Expr> reassignment{
+      std::move(exisitingVar), std::move(reassignmentExpr)};
+  parser::AssignmentStmt reassignStmt{std::move(reassignment)};
+  parser::ExecutionPartConstruct tempVariableReassignment{
+      parser::ExecutionPartConstruct{
+          parser::ExecutableConstruct{parser::Statement<parser::ActionStmt>{
+              std::optional<parser::Label>{}, std::move(reassignStmt)}}}};
+  block_->insert(std::next(reassignmentInsertionPoint_),
+      std::move(tempVariableReassignment));
+  currentTemp_->arrayElementReassigned();
+}
+
+void RewriteOmpReductionArrayElements::Post(
+    parser::AssignmentStmt &assignmentStmt) {
+  if (arrayElementReassigned_) {
+    // The typed expression needs to be reset where we are reassigning array
+    // elements so the semantics can regenerate the expressions correctly.
+    assignmentStmt.typedAssignment.Reset();
+  }
+}
+void RewriteOmpReductionArrayElements::Post(
+    parser::PointerAssignmentStmt &ptrAssignmentStmt) {
+  if (arrayElementReassigned_) {
+    // The typed expression needs to be reset where we are reassigning array
+    // elements so the semantics can regenerate the expressions correctly.
+    ptrAssignmentStmt.typedAssignment.Reset();
+  }
+}
+void RewriteOmpReductionArrayElements::Post(parser::Variable &var) {
+  if (currentTemp_) {
+    std::visit(
+        llvm::makeVisitor(
+            [&](common::Indirection<parser::FunctionReference> &funcRef)
+                -> void {
+              this->processFunctionReference<parser::Variable>(
+                  var, var.GetSource(), funcRef.value());
+            },
+            [&](common::Indirection<parser::Designator> &designator) -> void {
+              if (isMatchingArrayElement(designator.value())) {
+                designator = makeTempDesignator(designator.value().source);
+                var = parser::Variable{std::move(designator)};
+              }
+            },
+            [&](auto &) -> void {}),
+        var.u);
+  }
+  if (arrayElementReassigned_) {
+    // The typed expression needs to be reset where we are reassigning array
+    // elements so the semantics can regenerate the expressions correctly.
+    var.typedExpr.Reset();
+  }
+}
+void RewriteOmpReductionArrayElements::Post(parser::Expr &expr) {
+  if (currentTemp_) {
+    std::visit(
+        llvm::makeVisitor(
+            [&](common::Indirection<parser::FunctionReference> &funcRef)
+                -> void {
+              this->processFunctionReference<parser::Expr>(
+                  expr, expr.source, funcRef.value());
+            },
+            [&](common::Indirection<parser::Designator> &designator) -> void {
+              if (isMatchingArrayElement(designator.value())) {
+                designator = makeTempDesignator(designator.value().source);
+                expr = parser::Expr{std::move(designator)};
+              }
+            },
+            [&](auto &) {}),
+        expr.u);
+  }
+  if (arrayElementReassigned_) {
+    // The typed expression needs to be reset where we are reassigning array
+    // elements so the semantics can regenerate the expressions correctly.
+    expr.typedExpr.Reset();
+  }
+}
+
+void RewriteOmpReductionArrayElements::Post(parser::Block &block) {
+  rewriteReductionArrayElementToTemp(block);
+}
+
 bool RewriteParseTree(SemanticsContext &context, parser::Program &program) {
   RewriteMutator mutator{context};
   parser::Walk(program, mutator);
   return !context.AnyFatalError();
 }
 
+bool RewriteReductionArrayElements(
+    SemanticsContext &context, parser::Program &program) {
+  RewriteOmpReductionArrayElements mutator{context};
+  parser::Walk(program, mutator);
+  return mutator.isArrayElementRewritten();
+}
+
 } // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/rewrite-parse-tree.h b/flang/lib/Semantics/rewrite-parse-tree.h
index 313276db481d5..aef953aa5b9e8 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.h
+++ b/flang/lib/Semantics/rewrite-parse-tree.h
@@ -19,6 +19,8 @@ class SemanticsContext;
 
 namespace Fortran::semantics {
 bool RewriteParseTree(SemanticsContext &, parser::Program &);
+bool RewriteReductionArrayElements(
+    SemanticsContext &context, parser::Program &program);
 }
 
 #endif // FORTRAN_SEMANTICS_REWRITE_PARSE_TREE_H_
diff --git a/flang/lib/Semantics/semantics.cpp b/flang/lib/Semantics/semantics.cpp
index bdb5377265c14..678f6e6d0ce5c 100644
--- a/flang/lib/Semantics/semantics.cpp
+++ b/flang/lib/Semantics/semantics.cpp
@@ -209,7 +209,8 @@ static bool PerformStatementSemantics(
   RewriteParseTree(context, program);
   ComputeOffsets(context, context.globalScope());
   CheckDeclarations(context);
-  StatementSemanticsPass1{context}.Walk(program);
+  StatementSemanticsPass1 pass1{context};
+  pass1.Walk(program);
   StatementSemanticsPass2 pass2{context};
   pass2.Walk(program);
   if (context.languageFeatures().IsEnabled(common::LanguageFeature::OpenACC)) {
@@ -217,6 +218,16 @@ static bool PerformStatementSemantics(
   }
   if (context.languageFeatures().IsEnabled(common::LanguageFeature::OpenMP)) {
     SemanticsVisitor<OmpStructureChecker>{context}.Walk(program);
+    if (!context.AnyFatalError()) {
+      // Once semantics have been checked, we can replace any Array Elements
+      // used in Reductions with temporary variables to ensure they are lowered
+      // correctly.
+      if (RewriteReductionArrayElements(context, program)) {
+        // If any array elements have been rewritten to temporaries, the
+        // TypedExprs need recapturing, so pass1 is run again.
+        pass1.Walk(program);
+      }
+    }
   }
   if (context.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
     SemanticsVisitor<CUDAChecker>{context}.Walk(program);
diff --git a/flang/test/Semantics/OpenMP/reduction09.f90 b/flang/test/Semantics/OpenMP/reduction09.f90
index ca60805e8c416..080bbb976515d 100644
--- a/flang/test/Semantics/OpenMP/reduction09.f90
+++ b/flang/test/Semantics/OpenMP/reduction09.f90
@@ -1,22 +1,16 @@
-! RUN: %python %S/../test_symbols.py %s %flang_fc1 -fopenmp
+! RUN: %flang_fc1 -fdebug-unparse-with-symbols -fopenmp %s | FileCheck %s
 ! OpenMP Version 4.5
 ! 2.15.3.6 Reduction Clause Positive cases.
 !DEF: /OMP_REDUCTION MainProgram
 program OMP_REDUCTION
-  !DEF: /OMP_REDUCTION/i ObjectEntity INTEGER(4)
   integer i
-  !DEF: /OMP_REDUCTION/k ObjectEntity INTEGER(4)
   integer :: k = 10
-  !DEF: /OMP_REDUCTION/a ObjectEntity INTEGER(4)
   integer a(10)
-  !DEF: /OMP_REDUCTION/b ObjectEntity INTEGER(4)
   integer b(10,10,10)
 
   !$omp parallel  shared(k)
   !$omp do  reduction(+:k)
-  !DEF: /OMP_REDUCTION/OtherConstruct1/OtherConstruct1/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct1/OtherConstruct1/k (OmpReduction, OmpExplicit) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end do
@@ -24,53 +18,130 @@ program OMP_REDUCTION
 
 
   !$omp parallel do  reduction(+:a(10))
-  !DEF: /OMP_REDUCTION/OtherConstruct2/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct2/k (OmpShared) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end parallel do
 
+  !$omp parallel do  reduction(+:a(10))
+  do i=1,10
+    a(10) = a(10)+1
+  end do
+  !$omp end parallel do
 
   !$omp parallel do  reduction(+:a(1:10:1))
-  !DEF: /OMP_REDUCTION/OtherConstruct3/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct3/k (OmpShared) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end parallel do
 
   !$omp parallel do  reduction(+:b(1:10:1,1:5,2))
-  !DEF: /OMP_REDUCTION/OtherConstruct4/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct4/k (OmpShared) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end parallel do
 
   !$omp parallel do  reduction(+:b(1:10:1,1:5,2:5:1))
-  !DEF: /OMP_REDUCTION/OtherConstruct5/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct5/k (OmpShared) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end parallel do
 
   !$omp parallel  private(i)
   !$omp do reduction(+:k) reduction(+:j)
-  !DEF: /OMP_REDUCTION/OtherConstruct6/OtherConstruct1/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct6/OtherConstruct1/k (OmpReduction, OmpExplicit) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end do
   !$omp end parallel
 
   !$omp do reduction(+:k) reduction(*:j) reduction(+:l)
-  !DEF: /OMP_REDUCTION/OtherConstruct7/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
   do i=1,10
-    !DEF: /OMP_REDUCTION/OtherConstruct7/k (OmpReduction, OmpExplicit) HostAssoc INTEGER(4)
     k = k+1
   end do
   !$omp end do
 end program OMP_REDUCTION
+
+! CHECK: !DEF: /OMP_REDUCTION MainProgram
+! CHECK-NEXT: program OMP_REDUCTION
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/i ObjectEntity INTEGER(4)
+! CHECK-NEXT:  integer i
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/k ObjectEntity INTEGER(4)
+! CHECK-NEXT:  integer :: k = 10
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/a ObjectEntity INTEGER(4)
+! CHECK-NEXT:  integer a(10)
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/b ObjectEntity INTEGER(4)
+! CHECK-NEXT:  integer b(10,10,10)
+! CHECK-NEXT: !$omp parallel shared(k)
+! CHECK-NEXT: !$omp do reduction(+: k)
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct1/OtherConstruct1/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct1/OtherConstruct1/k (OmpReduction, OmpExplicit) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end do
+! CHECK-NEXT: !$omp end parallel
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/reduction_temp_a(10) (CompilerCreated) ObjectEntity INTEGER(4)
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/a
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct3/a (OmpShared) HostAssoc INTEGER(4)
+! CHECK-NEXT:  reduction_temp_a(10) = a(10)
+! CHECK-NEXT: !$omp parallel do reduction(+: reduction_temp_a(10))
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct2/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct2/k (OmpShared) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end parallel do
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/reduction_temp_a(10)
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/a
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/OtherConstruct3/a
+! CHECK-NEXT:  reduction_temp_a(10) = a(10)
+! CHECK-NEXT: !$omp parallel do reduction(+: reduction_temp_a(10))
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct3/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !REF: /OMP_REDUCTION/reduction_temp_a(10)
+! CHECK-NEXT:   reduction_temp_a(10) = reduction_temp_a(10)+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end parallel do
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/reduction_temp_a(10)
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/a
+! CHECK-NEXT:  !REF: /OMP_REDUCTION/OtherConstruct3/a
+! CHECK-NEXT:  a(10) = reduction_temp_a(10)
+! CHECK-NEXT: !$omp parallel do reduction(+: a(1:10:1))
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct4/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct4/k (OmpShared) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end parallel do
+! CHECK-NEXT: !$omp parallel do reduction(+: b(1:10:1,1:5,2))
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct5/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct5/k (OmpShared) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end parallel do
+! CHECK-NEXT: !$omp parallel do reduction(+: b(1:10:1,1:5,2:5:1))
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct6/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct6/k (OmpShared) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end parallel do
+! CHECK-NEXT: !$omp parallel private(i)
+! CHECK-NEXT: !$omp do reduction(+: k) reduction(+: j)
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct7/OtherConstruct1/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct7/OtherConstruct1/k (OmpReduction, OmpExplicit) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end do
+! CHECK-NEXT: !$omp end parallel
+! CHECK-NEXT: !$omp do reduction(+: k) reduction(*: j) reduction(+: l)
+! CHECK-NEXT:  !DEF: /OMP_REDUCTION/OtherConstruct8/i (OmpPrivate, OmpPreDetermined) HostAssoc INTEGER(4)
+! CHECK-NEXT:  do i=1,10
+! CHECK-NEXT:   !DEF: /OMP_REDUCTION/OtherConstruct8/k (OmpReduction, OmpExplicit) HostAssoc INTEGER(4)
+! CHECK-NEXT:   k = k+1
+! CHECK-NEXT:  end do
+! CHECK-NEXT: !$omp end do
+! CHECK-NEXT: end program OMP_REDUCTION
\ No newline at end of file
diff --git a/flang/test/Semantics/OpenMP/reduction17.f90 b/flang/test/Semantics/OpenMP/reduction17.f90
new file mode 100644
index 0000000000000..8e8c476b0a2cf
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/reduction17.f90
@@ -0,0 +1,209 @@
+! This test targets the RewriteArrayElements function in rewrite-parse-tree.cpp. This behaviour must work correctly; otherwise, OpenMP lowering of ArrayElements in reduction clauses will not function correctly.
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck %s --check-prefix=CHECK-TREE
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefix=CHECK-HLFIR
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck %s --check-prefix=CHECK-UNPARSE
+
+program test
+  integer a(2)
+  integer b(2)
+  integer c(2)
+  integer z(10)
+  integer :: k = 10
+
+!! When a scalar array element appears in a reduction clause, it is replaced with a temporary so that it is lowered as a scalar integer rather than as the whole array
+! CHECK-TREE: | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'reduction_temp_a(2)=a(2_8)'
+! CHECK-TREE-NEXT: | | | Variable = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | Expr = 'a(2_8)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | DataRef -> Name = 'a'
+! CHECK-TREE-NEXT: | | | | | SectionSubscript -> Integer -> Expr = '2_4'
+! CHECK-TREE-NEXT: | | | | | | LiteralConstant -> IntLiteralConstant = '2'
+!$omp do reduction (+: a(2))
+! CHECK-TREE: | | | | OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+! CHECK-TREE-NEXT: | | | | | Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+! CHECK-TREE-NEXT: | | | | | OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-HLFIR: omp.wsloop private(@_QFEi_private_i32 %11#0 -> %arg0 : !fir.ref<i32>) reduction(@add_reduction_i32 %15#0 -> %arg1 : !fir.ref<i32>) { 
+  do i = 1,2
+    a(2) = a(2) + i
+! CHECK-TREE: | | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'reduction_temp_a(2)=reduction_temp_a(2)+i'
+! CHECK-TREE-NEXT: | | | | | | Variable = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | Expr = 'reduction_temp_a(2)+i'
+! CHECK-TREE-NEXT: | | | | | | | Add
+! CHECK-TREE-NEXT: | | | | | | | | Expr = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | | Expr = 'i'
+! CHECK-TREE-NEXT: | | | | | | | | | Designator -> DataRef -> Name = 'i'
+! CHECK-HLFIR: hlfir.declare %arg0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK-HLFIR-NEXT: hlfir.declare %arg1 {uniq_name = "_QFEreduction_temp_a(2)"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK-HLFIR-NEXT: hlfir.assign %arg2 to %33#0 : i32, !fir.ref<i32>
+! CHECK-HLFIR-NEXT: fir.load %34#0 : !fir.ref<i32>
+! CHECK-HLFIR-NEXT: fir.load %33#0 : !fir.ref<i32>
+! CHECK-HLFIR-NEXT: arith.addi %35, %36 : i32
+! CHECK-HLFIR-NEXT: hlfir.assign %37 to %34#0 : i32, !fir.ref<i32>
+  end do
+!$omp end do
+! CHECK-TREE: | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'a(2_8)=reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | Variable = 'a(2_8)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | DataRef -> Name = 'a'
+! CHECK-TREE-NEXT: | | | | | SectionSubscript -> Integer -> Expr = '2_4'
+! CHECK-TREE-NEXT: | | | | | | LiteralConstant -> IntLiteralConstant = '2'
+! CHECK-TREE-NEXT: | | | Expr = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+
+!! Ensure that consecutive reduction clauses in the same block are processed correctly
+! CHECK-TREE: | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'reduction_temp_b(2)=b(2_8)'
+! CHECK-TREE-NEXT: | | | Variable = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> Name = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | Expr = 'b(2_8)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | DataRef -> Name = 'b'
+! CHECK-TREE-NEXT: | | | | | SectionSubscript -> Integer -> Expr = '2_4'
+! CHECK-TREE-NEXT: | | | | | | LiteralConstant -> IntLiteralConstant = '2'
+!$omp do reduction (+: b(2))
+! CHECK-TREE: | | | | OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+! CHECK-TREE-NEXT: | | | | | Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+! CHECK-TREE-NEXT: | | | | | OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'reduction_temp_b(2)'
+! CHECK-HLFIR: omp.wsloop private(@_QFEi_private_i32 %11#0 -> %arg0 : !fir.ref<i32>) reduction(@add_reduction_i32 %17#0 -> %arg1 : !fir.ref<i32>) {
+  do i = 1,3
+    b(2) = b(2) + i
+! CHECK-TREE: | | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'reduction_temp_b(2)=reduction_temp_b(2)+i'
+! CHECK-TREE-NEXT: | | | | | | Variable = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | | | | Expr = 'reduction_temp_b(2)+i'
+! CHECK-TREE-NEXT: | | | | | | | Add
+! CHECK-TREE-NEXT: | | | | | | | | Expr = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | | | | | | Expr = 'i'
+! CHECK-TREE-NEXT: | | | | | | | | | Designator -> DataRef -> Name = 'i'
+! CHECK-HLFIR: hlfir.declare %arg0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK-HLFIR-NEXT: hlfir.declare %arg1 {uniq_name = "_QFEreduction_temp_b(2)"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK-HLFIR-NEXT: hlfir.assign %arg2 to %33#0 : i32, !fir.ref<i32>
+! CHECK-HLFIR-NEXT: fir.load %34#0 : !fir.ref<i32>
+! CHECK-HLFIR-NEXT: fir.load %33#0 : !fir.ref<i32>
+! CHECK-HLFIR-NEXT: arith.addi %35, %36 : i32
+! CHECK-HLFIR-NEXT: hlfir.assign %37 to %34#0 : i32, !fir.ref<i32>
+  end do
+!$omp end do
+! CHECK-TREE: | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'b(2_8)=reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | Variable = 'b(2_8)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | DataRef -> Name = 'b'
+! CHECK-TREE-NEXT: | | | | | SectionSubscript -> Integer -> Expr = '2_4'
+! CHECK-TREE-NEXT: | | | | | | LiteralConstant -> IntLiteralConstant = '2'
+! CHECK-TREE-NEXT: | | | Expr = 'reduction_temp_b(2)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> Name = 'reduction_temp_b(2)'
+
+!! Ensure that the same array element can be reused later on. The temporary reuses the symbol created for the previous use of a(2).
+! CHECK-TREE: | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'reduction_temp_a(2)=a(2_8)'
+! CHECK-TREE-NEXT: | | | Variable = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | Expr = 'a(2_8)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | DataRef -> Name = 'a'
+! CHECK-TREE-NEXT: | | | | | SectionSubscript -> Integer -> Expr = '2_4'
+! CHECK-TREE-NEXT: | | | | | | LiteralConstant -> IntLiteralConstant = '2'
+!$omp do reduction (+: a(2))
+! CHECK-TREE: | | | | OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+! CHECK-TREE-NEXT: | | | | | Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+! CHECK-TREE-NEXT: | | | | | OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-HLFIR: omp.wsloop private(@_QFEi_private_i32 %11#0 -> %arg0 : !fir.ref<i32>) reduction(@add_reduction_i32 %15#0 -> %arg1 : !fir.ref<i32>) {
+  do i = 1,4
+    a(2) = a(2) + i
+! CHECK-TREE: | | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'reduction_temp_a(2)=reduction_temp_a(2)+i'
+! CHECK-TREE-NEXT: | | | | | | Variable = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | Expr = 'reduction_temp_a(2)+i'
+! CHECK-TREE-NEXT: | | | | | | | Add
+! CHECK-TREE-NEXT: | | | | | | | | Expr = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | | Expr = 'i'
+! CHECK-TREE-NEXT: | | | | | | | | | Designator -> DataRef -> Name = 'i'
+! CHECK-HLFIR: hlfir.declare %arg0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK-HLFIR-NEXT: hlfir.declare %arg1 {uniq_name = "_QFEreduction_temp_a(2)"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK-HLFIR-NEXT: hlfir.assign %arg2 to %33#0 : i32, !fir.ref<i32>
+! CHECK-HLFIR-NEXT: fir.load %34#0 : !fir.ref<i32>
+! CHECK-HLFIR-NEXT: fir.load %33#0 : !fir.ref<i32>
+! CHECK-HLFIR-NEXT: arith.addi %35, %36 : i32
+! CHECK-HLFIR-NEXT: hlfir.assign %37 to %34#0 : i32, !fir.ref<i32>
+    !! Ensure that an array element that is not being reduced, such as a(1), is not replaced with a temporary
+    a(1) = a(2)
+! CHECK-TREE: | | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'a(1_8)=reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | Variable = 'a(1_8)'
+! CHECK-TREE-NEXT: | | | | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | | | | DataRef -> Name = 'a'
+! CHECK-TREE-NEXT: | | | | | | | | SectionSubscript -> Integer -> Expr = '1_4'
+! CHECK-TREE-NEXT: | | | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
+! CHECK-TREE-NEXT: | | | | | | Expr = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+! CHECK-HLFIR: arith.constant 1 : index
+! CHECK-HLFIR-NEXT: hlfir.designate %3#0 (%c1)  : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
+! CHECK-HLFIR-NEXT: hlfir.assign %38 to %39 : i32, !fir.ref<i32>
+  end do
+!$omp end do
+! CHECK-TREE: | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AssignmentStmt = 'a(2_8)=reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | Variable = 'a(2_8)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | DataRef -> Name = 'a'
+! CHECK-TREE-NEXT: | | | | | SectionSubscript -> Integer -> Expr = '2_4'
+! CHECK-TREE-NEXT: | | | | | | LiteralConstant -> IntLiteralConstant = '2'
+! CHECK-TREE-NEXT: | | | Expr = 'reduction_temp_a(2)'
+! CHECK-TREE-NEXT: | | | | Designator -> DataRef -> Name = 'reduction_temp_a(2)'
+
+!! Array sections are not rewritten
+  !$omp parallel do reduction(+:z(1:10:1))
+! CHECK-TREE: | | | | OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+! CHECK-TREE-NEXT: | | | | | Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+! CHECK-TREE-NEXT: | | | | | OmpObjectList -> OmpObject -> Designator -> DataRef -> ArrayElement
+! CHECK-TREE-NEXT: | | | | | | DataRef -> Name = 'z'
+! CHECK-TREE-NEXT: | | | | | | SectionSubscript -> SubscriptTriplet
+! CHECK-TREE-NEXT: | | | | | | | Scalar -> Integer -> Expr = '1_4'
+! CHECK-TREE-NEXT: | | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
+! CHECK-TREE-NEXT: | | | | | | | Scalar -> Integer -> Expr = '10_4'
+! CHECK-TREE-NEXT: | | | | | | | | LiteralConstant -> IntLiteralConstant = '10'
+! CHECK-TREE-NEXT: | | | | | | | Scalar -> Integer -> Expr = '1_4'
+! CHECK-TREE-NEXT: | | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
+! CHECK-HLFIR: omp.wsloop private(@_QFEi_private_i32 %11#0 -> %arg0 : !fir.ref<i32>) reduction(byref @add_reduction_byref_box_10xi32 %34 -> %arg1 : !fir.ref<!fir.box<!fir.array<10xi32>>>) {
+  do i=1,10
+    k = k + 1
+  end do
+  !$omp end parallel do
+
+end program test
+
+! CHECK-UNPARSE: PROGRAM TEST
+! CHECK-UNPARSE-NEXT:  INTEGER a(2_4)
+! CHECK-UNPARSE-NEXT:  INTEGER b(2_4)
+! CHECK-UNPARSE-NEXT:  INTEGER c(2_4)
+! CHECK-UNPARSE-NEXT:  INTEGER z(10_4)
+! CHECK-UNPARSE-NEXT:  INTEGER :: k = 10_4
+! CHECK-UNPARSE-NEXT:   reduction_temp_a(2)=a(2_8)
+! CHECK-UNPARSE-NEXT: !$OMP DO REDUCTION(+: reduction_temp_a(2))
+! CHECK-UNPARSE-NEXT:  DO i=1_4,2_4
+! CHECK-UNPARSE-NEXT:    reduction_temp_a(2)=reduction_temp_a(2)+i
+! CHECK-UNPARSE-NEXT:  END DO
+! CHECK-UNPARSE-NEXT: !$OMP END DO
+! CHECK-UNPARSE-NEXT:   a(2_8)=reduction_temp_a(2)
+! CHECK-UNPARSE-NEXT:   reduction_temp_b(2)=b(2_8)
+! CHECK-UNPARSE-NEXT: !$OMP DO REDUCTION(+: reduction_temp_b(2))
+! CHECK-UNPARSE-NEXT:  DO i=1_4,3_4
+! CHECK-UNPARSE-NEXT:    reduction_temp_b(2)=reduction_temp_b(2)+i
+! CHECK-UNPARSE-NEXT:  END DO
+! CHECK-UNPARSE-NEXT: !$OMP END DO
+! CHECK-UNPARSE-NEXT:   b(2_8)=reduction_temp_b(2)
+! CHECK-UNPARSE-NEXT:   reduction_temp_a(2)=a(2_8)
+! CHECK-UNPARSE-NEXT: !$OMP DO REDUCTION(+: reduction_temp_a(2))
+! CHECK-UNPARSE-NEXT:  DO i=1_4,4_4
+! CHECK-UNPARSE-NEXT:    reduction_temp_a(2)=reduction_temp_a(2)+i
+! CHECK-UNPARSE-NEXT:    a(1_8)=reduction_temp_a(2)
+! CHECK-UNPARSE-NEXT:  END DO
+! CHECK-UNPARSE-NEXT: !$OMP END DO
+! CHECK-UNPARSE-NEXT:   a(2_8)=reduction_temp_a(2)
+! CHECK-UNPARSE-NEXT: !$OMP PARALLEL DO REDUCTION(+: z(1_4:10_4:1_4))
+! CHECK-UNPARSE-NEXT:  DO i=1_4,10_4
+! CHECK-UNPARSE-NEXT:    k=k+1_4
+! CHECK-UNPARSE-NEXT:  END DO
+! CHECK-UNPARSE-NEXT: !$OMP END PARALLEL DO
+! CHECK-UNPARSE-NEXT: END PROGRAM TEST

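The rewrite that these tests pin down follows a copy-in, substitute, copy-back pattern: the temporary is initialised from the array element before the construct, every use of that exact element inside the DO loop is replaced with the temporary, and the value is written back after the construct only if the loop referenced the element. A minimal sketch of that pattern is shown below. It is a hypothetical model, not the patch's code: statements are plain token lists, the names Statement, RewriteResult and rewriteScalarElementReduction are invented for illustration, and the array-section check (looking for a ':') is deliberately crude. The real implementation rewrites parser::Block and parser::Designator nodes in rewrite-parse-tree.cpp.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// A statement modelled as a list of tokens, e.g. {"a(2)", "=", "a(2)", "+", "i"}.
// Hypothetical simplification: the patch rewrites parse-tree nodes instead.
using Statement = std::vector<std::string>;

struct RewriteResult {
  std::vector<Statement> before; // emitted before the OpenMP construct
  std::string reductionObject;   // object named in the rewritten clause
  std::vector<Statement> body;   // rewritten DO-loop body
  std::vector<Statement> after;  // emitted after the OpenMP construct
};

// Sketch of the rewrite for `reduction(op: element)` over a DO loop.
RewriteResult rewriteScalarElementReduction(
    const std::string &element, std::vector<Statement> body) {
  assert(!element.empty() && "reduction object must be named");
  RewriteResult result;
  // Array sections such as z(1:10:1) are left untouched (crude check).
  if (element.find(':') != std::string::npos) {
    result.reductionObject = element;
    result.body = std::move(body);
    return result;
  }
  const std::string temp = "reduction_temp_" + element;
  // Copy-in: always emitted, and the clause now names the temporary.
  result.before.push_back({temp, "=", element});
  result.reductionObject = temp;
  // Substitute the temporary for every use of this exact element; other
  // elements of the same array (e.g. a(1)) keep their original token.
  bool used = false;
  for (Statement &stmt : body) {
    for (std::string &token : stmt) {
      if (token == element) {
        token = temp;
        used = true;
      }
    }
  }
  result.body = std::move(body);
  // Copy-back: only emitted when the loop body referenced the element.
  if (used) {
    result.after.push_back({element, "=", temp});
  }
  return result;
}
```

For example, rewriting a(2) over the body `a(2) = a(2) + i` followed by `a(1) = a(2)` gives `reduction_temp_a(2) = reduction_temp_a(2) + i` and `a(1) = reduction_temp_a(2)` plus one copy-back, mirroring the CHECK-UNPARSE lines in reduction17.f90. Rewriting a(10) over a body that only updates k produces a copy-in but no copy-back, as in reduction09.f90.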
>From 8bf23605feeff800fc3726575eb5dadf59613dc4 Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Mon, 20 Oct 2025 09:06:30 +0100
Subject: [PATCH 2/2] Code Style improvements & respond to comments

Collate the nested std::visit calls into named lambda functions to make
them easier to read.
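The refactoring replaces deeply nested std::visit calls with one named lambda per level of the variant tree (visitIntExpr, visitLiteralConstant and so on in the hunks below). A minimal sketch of the same shape follows, under stated assumptions: the node types IntLiteral, RealLiteral, LiteralConstant, Name and Expr and the function findIntLiteral are hypothetical stand-ins for the parser:: types, and a small C++17 Overloaded helper replaces llvm::makeVisitor.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <variant>

// Stand-in for llvm::makeVisitor: combines several lambdas into one callable.
template <typename... Fs> struct Overloaded : Fs... {
  using Fs::operator()...;
};
template <typename... Fs> Overloaded(Fs...) -> Overloaded<Fs...>;

// Toy node types loosely mirroring the parser:: variants (hypothetical).
struct IntLiteral { std::string text; };
struct RealLiteral { std::string text; };
struct LiteralConstant { std::variant<IntLiteral, RealLiteral> u; };
struct Name { std::string text; };
struct Expr { std::variant<LiteralConstant, Name> u; };

// Each level of the tree gets its own named lambda, so every std::visit
// handles exactly one level instead of being nested inline.
std::optional<std::string> findIntLiteral(const Expr &expr) {
  std::optional<std::string> found;
  auto visitLiteralConstant = [&](const LiteralConstant &literal) {
    std::visit(Overloaded{
                   [&](const IntLiteral &intLiteral) { found = intLiteral.text; },
                   [](const auto &) {}},
        literal.u);
  };
  std::visit(Overloaded{
                 [&](const LiteralConstant &literal) {
                   visitLiteralConstant(literal);
                 },
                 [](const auto &) {}},
      expr.u);
  return found;
}
```

Adding a case now means editing one small lambda rather than re-indenting a nested chain. The patch uses llvm::makeVisitor from llvm/ADT/STLExtras.h in the same role; the Overloaded struct only keeps this sketch free of LLVM dependencies.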
---
 flang/lib/Semantics/rewrite-parse-tree.cpp | 458 +++++++++++----------
 1 file changed, 236 insertions(+), 222 deletions(-)

diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 5379dcdd3d40c..adbb6b1d03893 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -575,10 +575,9 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
 void ReplacementTemp::createTempSymbol(
     SourceName &source, Scope &scope, SemanticsContext &context) {
   replacementTempSymbol_ =
-      const_cast<semantics::Scope &>(originalName_.symbol->owner())
-          .FindSymbol(source);
+      const_cast<Scope &>(originalName_.symbol->owner()).FindSymbol(source);
   replacementTempSymbol_->set_scope(
-      &const_cast<semantics::Scope &>(originalName_.symbol->owner()));
+      &const_cast<Scope &>(originalName_.symbol->owner()));
   DeclTypeSpec *tempType = originalName_.symbol->GetUltimate().GetType();
   replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
   replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
@@ -587,27 +586,28 @@ void ReplacementTemp::createTempSymbol(
 void ReplacementTemp::setOriginalSubscriptInt(
     std::list<parser::SectionSubscript> &sectionSubscript) {
   bool setSubscript{false};
-  for (parser::SectionSubscript &subscript : sectionSubscript) {
+  auto visitLiteralConstant = [&](parser::LiteralConstant &literalConstant) {
     std::visit(llvm::makeVisitor(
-                   [&](parser::IntExpr &intExpr) {
-                     parser::Expr &expr = intExpr.thing.value();
-                     std::visit(
-                         llvm::makeVisitor(
-                             [&](parser::LiteralConstant &literalContant) {
-                               std::visit(llvm::makeVisitor(
-                                              [&](parser::IntLiteralConstant
-                                                      &intLiteralConstant) {
-                                                originalSubscriptCharBlock_ =
-                                                    std::get<parser::CharBlock>(
-                                                        intLiteralConstant.t);
-                                                setSubscript = true;
-                                              },
-                                              [&](auto &) {}),
-                                   literalContant.u);
-                             },
-                             [&](auto &) {}),
-                         expr.u);
+                   [&](parser::IntLiteralConstant &intLiteralConstant) {
+                     originalSubscriptCharBlock_ =
+                         std::get<parser::CharBlock>(intLiteralConstant.t);
+                     setSubscript = true;
                    },
+                   [&](auto &) {}),
+        literalConstant.u);
+  };
+  auto visitIntExpr = [&](parser::IntExpr &intExpr) {
+    parser::Expr &expr = intExpr.thing.value();
+    std::visit(llvm::makeVisitor(
+                   [&](parser::LiteralConstant &literalConstant) {
+                     visitLiteralConstant(literalConstant);
+                   },
+                   [&](auto &) {}),
+        expr.u);
+  };
+  for (parser::SectionSubscript &subscript : sectionSubscript) {
+    std::visit(llvm::makeVisitor(
+                   [&](parser::IntExpr &intExpr) { visitIntExpr(intExpr); },
                    [&](parser::SubscriptTriplet &triplet) {
                      isSectionTriplet_ = true;
                      setSubscript = true;
@@ -626,47 +626,53 @@ void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
     return;
   }
 
-  for (auto it{block.begin()}; it != block.end(); ++it) {
+  auto visitOpenMPLoopConstruct = [&](parser::OpenMPLoopConstruct &ompLoop,
+                                      parser::Block::iterator &it) {
+    ReplacementTemp temp;
+    if (!rewriteArrayElementToTemp(it, ompLoop, block, temp)) {
+      return;
+    }
+    auto &NestedConstruct =
+        std::get<std::optional<parser::NestedConstruct>>(ompLoop.t);
+    if (!NestedConstruct.has_value()) {
+      return;
+    }
+    if (parser::DoConstruct *
+        doConst{std::get_if<parser::DoConstruct>(&NestedConstruct.value())}) {
+      block_ = &block;
+      parser::Block &doBlock{std::get<parser::Block>(doConst->t)};
+      parser::Walk(doBlock, *this);
+      // Reset the current temp value so future
+      // iterations use their own version.
+      resetCurrentTemp();
+    }
+  };
+  auto visitOpenMPConstruct = [&](parser::OpenMPConstruct &ompConstruct,
+                                  parser::Block::iterator &it) {
+    std::visit(llvm::makeVisitor(
+                   [&](parser::OpenMPLoopConstruct &ompLoop) {
+                     visitOpenMPLoopConstruct(ompLoop, it);
+                   },
+                   [&](auto &) {}),
+        ompConstruct.u);
+  };
+  auto visitExecutableConstruct = [&](parser::ExecutableConstruct
+                                          &execConstruct,
+                                      parser::Block::iterator &it) {
     std::visit(
         llvm::makeVisitor(
-            [&](parser::ExecutableConstruct &execConstruct) {
-              std::visit(
-                  llvm::makeVisitor(
-                      [&](common::Indirection<parser::OpenMPConstruct>
-                              &ompConstruct) {
-                        std::visit(
-                            llvm::makeVisitor(
-                                [&](parser::OpenMPLoopConstruct &ompLoop) {
-                                  ReplacementTemp temp;
-                                  if (!rewriteArrayElementToTemp(
-                                          it, ompLoop, block, temp)) {
-                                    return;
-                                  }
-                                  auto &NestedConstruct = std::get<
-                                      std::optional<parser::NestedConstruct>>(
-                                      ompLoop.t);
-                                  if (!NestedConstruct.has_value()) {
-                                    return;
-                                  }
-                                  if (parser::DoConstruct *
-                                      doConst{std::get_if<parser::DoConstruct>(
-                                          &NestedConstruct.value())}) {
-                                    block_ = &block;
-                                    parser::Block &doBlock{
-                                        std::get<parser::Block>(doConst->t)};
-                                    parser::Walk(doBlock, *this);
-                                    // Reset the current temp value so future
-                                    // iterations use their own version.
-                                    resetCurrentTemp();
-                                  }
-                                },
-                                [&](auto &) {}),
-                            ompConstruct.value().u);
-                      },
-                      [&](auto &) {}),
-                  execConstruct.u);
+            [&](common::Indirection<parser::OpenMPConstruct> &ompConstruct) {
+              visitOpenMPConstruct(ompConstruct.value(), it);
             },
             [&](auto &) {}),
+        execConstruct.u);
+  };
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    std::visit(llvm::makeVisitor(
+                   [&](parser::ExecutableConstruct &execConstruct) {
+                     visitExecutableConstruct(execConstruct, it);
+                   },
+                   [&](auto &) {}),
         it->u);
   }
 }
@@ -676,76 +682,84 @@ bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
   bool matchesArrayElement{false};
   std::list<parser::SectionSubscript> *subscripts{nullptr};
 
+  auto visitName = [&](parser::Name &name, parser::ArrayElement &arrayElement) {
+    if (name.symbol->GetUltimate() ==
+        currentTemp_->getOriginalName().symbol->GetUltimate()) {
+      matchesArrayElement = true;
+      if (!currentTemp_->isArrayElementReassigned()) {
+        reassignTempValueToArrayElement(arrayElement);
+      }
+    }
+  };
+  auto visitArrayElement = [&](parser::ArrayElement &arrayElement) {
+    subscripts = &arrayElement.subscripts;
+    std::visit(llvm::makeVisitor(
+                   [&](parser::Name &name) { visitName(name, arrayElement); },
+                   [](auto &) {}),
+        arrayElement.base.u);
+  };
+  auto visitDataRef = [&](parser::DataRef &dataRef) {
+    std::visit(
+        llvm::makeVisitor(
+            [&](common::Indirection<parser::ArrayElement> &arrayElement) {
+              visitArrayElement(arrayElement.value());
+            },
+            [&](parser::Name &name) {
+              if (name.symbol->GetUltimate() ==
+                  currentTemp_->getOriginalName().symbol->GetUltimate()) {
+                matchesArrayElement = true;
+              }
+            },
+            [](auto &) {}),
+        dataRef.u);
+  };
   std::visit(llvm::makeVisitor(
-                 [&](parser::DataRef &dataRef) {
-                   std::visit(
-                       llvm::makeVisitor(
-                           [&](common::Indirection<parser::ArrayElement>
-                                   &arrayElement) {
-                             subscripts = &arrayElement.value().subscripts;
-                             std::visit(
-                                 llvm::makeVisitor(
-                                     [&](parser::Name &name) {
-                                       if (name.symbol->GetUltimate() ==
-                                           currentTemp_->getOriginalName()
-                                               .symbol->GetUltimate()) {
-                                         matchesArrayElement = true;
-                                         if (!currentTemp_
-                                                 ->isArrayElementReassigned()) {
-                                           reassignTempValueToArrayElement(
-                                               arrayElement.value());
-                                         }
-                                       }
-                                     },
-                                     [](auto &) {}),
-                                 arrayElement.value().base.u);
-                           },
-                           [&](parser::Name &name) {
-                             if (name.symbol->GetUltimate() ==
-                                 currentTemp_->getOriginalName()
-                                     .symbol->GetUltimate()) {
-                               matchesArrayElement = true;
-                             }
-                           },
-                           [](auto &) {}),
-                       dataRef.u);
-                 },
+                 [&](parser::DataRef &dataRef) { visitDataRef(dataRef); },
                  [&](auto &) {}),
       existingDesignator.u);
 
   if (subscripts) {
     bool foundSubscript{false};
-    for (parser::SectionSubscript &subscript : *subscripts) {
-      matchesArrayElement = std::visit(
+    auto visitIntLiteralConstant =
+        [&](parser::IntLiteralConstant &intLiteralConstant) -> bool {
+      foundSubscript = true;
+      assert(currentTemp_ != nullptr &&
+          "Value for ReplacementTemp should have "
+          "been found");
+      if (std::get<parser::CharBlock>(intLiteralConstant.t) ==
+          currentTemp_->getOriginalSubscript()) {
+        return true;
+      }
+      return false;
+    };
+    auto visitLiteralConstant =
+        [&](parser::LiteralConstant &literalConstant) -> bool {
+      return std::visit(
           llvm::makeVisitor(
-              [&](parser::IntExpr &intExpr) -> bool {
-                parser::Expr &expr = intExpr.thing.value();
-                return std::visit(
-                    llvm::makeVisitor(
-                        [&](parser::LiteralConstant &literalContant) -> bool {
-                          return std::visit(
-                              llvm::makeVisitor(
-                                  [&](parser::IntLiteralConstant
-                                          &intLiteralConstant) -> bool {
-                                    foundSubscript = true;
-                                    assert(currentTemp_ != nullptr &&
-                                        "Value for ReplacementTemp should have "
-                                        "been found");
-                                    if (std::get<parser::CharBlock>(
-                                            intLiteralConstant.t) ==
-                                        currentTemp_->getOriginalSubscript()) {
-                                      return true;
-                                    }
-                                    return false;
-                                  },
-                                  [](auto &) -> bool { return false; }),
-                              literalContant.u);
-                        },
-                        [](auto &) -> bool { return false; }),
-                    expr.u);
+              [&](parser::IntLiteralConstant &intLiteralConstant) -> bool {
+                return visitIntLiteralConstant(intLiteralConstant);
+              },
+              [](auto &) -> bool { return false; }),
+          literalConstant.u);
+    };
+    auto visitIntExpr = [&](parser::IntExpr &intExpr) -> bool {
+      parser::Expr &expr = intExpr.thing.value();
+      return std::visit(
+          llvm::makeVisitor(
+              [&](parser::LiteralConstant &literalConstant) -> bool {
+                return visitLiteralConstant(literalConstant);
               },
               [](auto &) -> bool { return false; }),
-          subscript.u);
+          expr.u);
+    };
+    for (parser::SectionSubscript &subscript : *subscripts) {
+      matchesArrayElement =
+          std::visit(llvm::makeVisitor(
+                         [&](parser::IntExpr &intExpr) -> bool {
+                           return visitIntExpr(intExpr);
+                         },
+                         [](auto &) -> bool { return false; }),
+              subscript.u);
       if (foundSubscript) {
         break;
       }
@@ -757,16 +771,21 @@ bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
 template <typename T>
 void RewriteOmpReductionArrayElements::processFunctionReference(
     T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  auto visitFunctionReferenceName = [&](parser::Name &functionReferenceName)
+      -> std::optional<parser::Designator> {
+    if (currentTemp_->getOriginalName().symbol ==
+        functionReferenceName.symbol) {
+      return funcRef.ConvertToArrayElementRef();
+    }
+    return std::nullopt;
+  };
+
   auto &[procedureDesignator, ArgSpecList] = funcRef.v.t;
   std::optional<parser::Designator> arrayElementDesignator =
       std::visit(llvm::makeVisitor(
                      [&](parser::Name &functionReferenceName)
                          -> std::optional<parser::Designator> {
-                       if (currentTemp_->getOriginalName().symbol ==
-                           functionReferenceName.symbol) {
-                         return funcRef.ConvertToArrayElementRef();
-                       }
-                       return std::nullopt;
+                       return visitFunctionReferenceName(functionReferenceName);
                      },
                      [&](auto &) -> std::optional<parser::Designator> {
                        return std::nullopt;
@@ -798,86 +817,79 @@ bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
       std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
   std::list<parser::OmpClause> &clauseList{
       std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
-  bool rewrittenArrayElement{false};
 
-  for (auto iter{clauseList.begin()}; iter != clauseList.end(); ++iter) {
+  auto visitDesignator = [&](parser::Designator &designator) {
+    if (!identifyArrayElementReduced(designator, temp)) {
+      return false;
+    }
+    if (temp.isSectionTriplet()) {
+      return false;
+    }
+
+    reassignmentInsertionPoint_ = it;
+    std::string tempSourceString = "reduction_temp_" +
+        temp.getOriginalSource().ToString() + "(" +
+        temp.getOriginalSubscript().ToString() + ")";
+    SourceName source = context_.SaveTempName(std::move(tempSourceString));
+    Scope &scope = const_cast<Scope &>(temp.getOriginalName().symbol->owner());
+    if (Symbol * symbol{scope.FindSymbol(source)}) {
+      temp.setTempSymbol(symbol);
+    } else {
+      if (scope.try_emplace(source, Attrs{}, ObjectEntityDetails{}).second) {
+        temp.createTempSymbol(source, scope, context_);
+      } else {
+        common::die(
+            "Failed to create temp symbol for %s", source.ToString().c_str());
+      }
+    }
+    setCurrentTemp(&temp);
+    temp.setTempScope(scope);
+
+    // Assign the value of the array element to the
+    // temporary variable
+    parser::Variable newVariable{makeTempDesignator(temp.getOriginalSource())};
+    parser::Expr newExpr{
+        common::Indirection<parser::Designator>{std::move(designator)}};
+    newExpr.source = temp.getOriginalSource();
+    std::tuple<parser::Variable, parser::Expr> newT{
+        std::move(newVariable), std::move(newExpr)};
+    parser::AssignmentStmt assignment{std::move(newT)};
+    parser::ExecutionPartConstruct tempVariablePartConstruct{
+        parser::ExecutionPartConstruct{
+            parser::ExecutableConstruct{parser::Statement<parser::ActionStmt>{
+                std::optional<parser::Label>{}, std::move(assignment)}}}};
+    block.insert(it, std::move(tempVariablePartConstruct));
+    arrayElementReassigned_ = true;
+
+    designator = makeTempDesignator(temp.getOriginalSource());
+    return true;
+  };
+  auto processReductionClause =
+      [&](parser::OmpClause::Reduction &clause) -> bool {
+    std::list<parser::OmpObject> &objectList =
+        std::get<parser::OmpObjectList>(clause.v.t).v;
+
+    bool rewritten{false};
+    for (parser::OmpObject &object : objectList) {
+      rewritten |= std::visit(llvm::makeVisitor(
+                                  [&](parser::Designator &designator) -> bool {
+                                    return visitDesignator(designator);
+                                  },
+                                  [&](const auto &) -> bool { return false; }),
+          object.u);
+    }
+    return rewritten;
+  };
+
+  bool rewrittenArrayElement{false};
+  for (parser::OmpClause &clause : clauseList) {
     rewrittenArrayElement = std::visit(
         llvm::makeVisitor(
-            [&](parser::OmpClause::Reduction &clause) -> bool {
-              std::list<parser::OmpObject> &objectList =
-                  std::get<parser::OmpObjectList>(clause.v.t).v;
-
-              bool rewritten{false};
-              for (auto object{objectList.begin()}; object != objectList.end();
-                  ++object) {
-                rewritten = std::visit(
-                    llvm::makeVisitor(
-                        [&](parser::Designator &designator) -> bool {
-                          if (!identifyArrayElementReduced(designator, temp)) {
-                            return false;
-                          }
-                          if (temp.isSectionTriplet()) {
-                            return false;
-                          }
-
-                          reassignmentInsertionPoint_ =
-                              it != block.end() ? it : block.end();
-                          std::string tempSourceString = "reduction_temp_" +
-                              temp.getOriginalSource().ToString() + "(" +
-                              temp.getOriginalSubscript().ToString() + ")";
-                          SourceName source = context_.SaveTempName(
-                              std::move(tempSourceString));
-                          Scope &scope = const_cast<Scope &>(
-                              temp.getOriginalName().symbol->owner());
-                          if (Symbol * symbol{scope.FindSymbol(source)}) {
-                            temp.setTempSymbol(symbol);
-                          } else {
-                            if (scope
-                                    .try_emplace(source, semantics::Attrs{},
-                                        semantics::ObjectEntityDetails{})
-                                    .second) {
-                              temp.createTempSymbol(source, scope, context_);
-                            } else {
-                              common::die("Failed to create temp symbol for %s",
-                                  source.ToString().c_str());
-                            }
-                          }
-                          setCurrentTemp(&temp);
-                          temp.setTempScope(scope);
-
-                          // Assign the value of the array element to the
-                          // temporary variable
-                          parser::Variable newVariable{
-                              makeTempDesignator(temp.getOriginalSource())};
-                          parser::Expr newExpr{
-                              common::Indirection<parser::Designator>{
-                                  std::move(designator)}};
-                          newExpr.source = temp.getOriginalSource();
-                          std::tuple<parser::Variable, parser::Expr> newT{
-                              std::move(newVariable), std::move(newExpr)};
-                          parser::AssignmentStmt assignment{std::move(newT)};
-                          parser::ExecutionPartConstruct
-                              tempVariablePartConstruct{
-                                  parser::ExecutionPartConstruct{
-                                      parser::ExecutableConstruct{
-                                          parser::Statement<parser::ActionStmt>{
-                                              std::optional<parser::Label>{},
-                                              std::move(assignment)}}}};
-                          block.insert(
-                              it, std::move(tempVariablePartConstruct));
-                          arrayElementReassigned_ = true;
-
-                          designator =
-                              makeTempDesignator(temp.getOriginalSource());
-                          return true;
-                        },
-                        [&](const auto &) -> bool { return false; }),
-                    object->u);
-              }
-              return rewritten;
+            [&](parser::OmpClause::Reduction &reductionClause) -> bool {
+              return processReductionClause(reductionClause);
             },
             [&](auto &) -> bool { return false; }),
-        iter->u);
+        clause.u);
 
     if (rewrittenArrayElement) {
       return rewrittenArrayElement;
@@ -888,28 +900,30 @@ bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
 
 bool RewriteOmpReductionArrayElements::identifyArrayElementReduced(
     parser::Designator &designator, ReplacementTemp &temp) {
-  return std::visit(
-      llvm::makeVisitor(
-          [&](parser::DataRef &dataRef) -> bool {
-            return std::visit(
-                llvm::makeVisitor(
-                    [&](common::Indirection<parser::ArrayElement>
-                            &arrayElement) {
-                      std::visit(llvm::makeVisitor(
-                                     [&](parser::Name &name) -> void {
-                                       temp.setOriginalName(name);
-                                       temp.setOriginalSource(name.source);
-                                     },
-                                     [&](auto &) -> void {}),
-                          arrayElement.value().base.u);
-                      temp.setOriginalSubscriptInt(
-                          arrayElement.value().subscripts);
-                      return !temp.isSectionTriplet() ? true : false;
-                    },
-                    [&](auto &) -> bool { return false; }),
-                dataRef.u);
-          },
-          [&](auto &) -> bool { return false; }),
+  auto visitArrayElement = [&](parser::ArrayElement &arrayElement) -> bool {
+    std::visit(llvm::makeVisitor(
+                   [&](parser::Name &name) -> void {
+                     temp.setOriginalName(name);
+                     temp.setOriginalSource(name.source);
+                   },
+                   [&](auto &) -> void {}),
+        arrayElement.base.u);
+    temp.setOriginalSubscriptInt(arrayElement.subscripts);
+    return !temp.isSectionTriplet();
+  };
+  auto visitDataRef = [&](parser::DataRef &dataRef) -> bool {
+    return std::visit(
+        llvm::makeVisitor(
+            [&](common::Indirection<parser::ArrayElement> &arrayElement)
+                -> bool { return visitArrayElement(arrayElement.value()); },
+            [&](auto &) -> bool { return false; }),
+        dataRef.u);
+  };
+  return std::visit(llvm::makeVisitor(
+                        [&](parser::DataRef &dataRef) -> bool {
+                          return visitDataRef(dataRef);
+                        },
+                        [&](auto &) -> bool { return false; }),
       designator.u);
 }
 
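This revision replaces deeply nested `std::visit` calls with small named lambdas, one per level of the parse tree (`visitIntExpr`, `visitLiteralConstant`, `visitIntLiteralConstant`, and so on). The sketch below shows that pattern on its own. It is not flang code. The node types are hypothetical, heavily simplified stand-ins for `parser::SectionSubscript`, `parser::IntExpr`, `parser::LiteralConstant` and `parser::IntLiteralConstant`. Subscripts are compared as plain strings, and `Overloaded` is a local stand-in for `llvm::makeVisitor`.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

// Local stand-in for llvm::makeVisitor: merges several lambdas into one
// overloaded callable usable with std::visit (C++17).
template <typename... Ts> struct Overloaded : Ts... {
  using Ts::operator()...;
};
template <typename... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;

// Hypothetical, heavily simplified stand-ins for flang parse-tree nodes.
struct IntLiteralConstant { std::string text; };
struct RealLiteralConstant { std::string text; };
using LiteralConstant = std::variant<IntLiteralConstant, RealLiteralConstant>;
struct Name { std::string id; };
using Expr = std::variant<LiteralConstant, Name>;
struct IntExpr { Expr thing; };
struct SubscriptTriplet {};
using SectionSubscript = std::variant<IntExpr, SubscriptTriplet>;

// Returns true if the first integer-literal subscript is spelled `wanted`.
// Each nesting level gets its own named lambda instead of one deeply
// nested std::visit expression.
bool matchesSubscript(const std::vector<SectionSubscript> &subscripts,
                      const std::string &wanted) {
  bool foundSubscript{false};
  auto visitIntLiteralConstant = [&](const IntLiteralConstant &c) -> bool {
    foundSubscript = true;
    return c.text == wanted;
  };
  auto visitLiteralConstant = [&](const LiteralConstant &c) -> bool {
    return std::visit(
        Overloaded{[&](const IntLiteralConstant &i) -> bool {
                     return visitIntLiteralConstant(i);
                   },
                   [](const auto &) -> bool { return false; }},
        c);
  };
  auto visitIntExpr = [&](const IntExpr &e) -> bool {
    return std::visit(
        Overloaded{[&](const LiteralConstant &c) -> bool {
                     return visitLiteralConstant(c);
                   },
                   [](const auto &) -> bool { return false; }},
        e.thing);
  };

  bool matches{false};
  for (const SectionSubscript &subscript : subscripts) {
    matches = std::visit(
        Overloaded{[&](const IntExpr &e) -> bool { return visitIntExpr(e); },
                   [](const auto &) -> bool { return false; }},
        subscript);
    if (foundSubscript) {
      break; // stop at the first integer-literal subscript
    }
  }
  return matches;
}
```

For a subscript list modelling `a(2)`, `matchesSubscript(subs, "2")` is true and `matchesSubscript(subs, "3")` is false. A section triplet never matches. Naming each level keeps every lambda short and keeps each `[](const auto &) { return false; }` fallback next to the alternative it guards.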

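`processReductionClause` also changes `rewritten = std::visit(...)` to `rewritten |= std::visit(...)`. With plain assignment, the flag only reflected the last object in the clause's object list. A clause such as `reduction(+: a(1), b)` could therefore report that nothing was rewritten even though `a(1)` was. Here is a minimal sketch of the difference. Each object's result is modelled as a `bool`, and the function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Each entry models whether one object in a reduction clause's object list
// was rewritten to a temporary, e.g. reduction(+: a(1), b) -> {true, false}.

// Old pattern: plain assignment keeps only the last object's result.
bool anyRewrittenAssign(const std::vector<bool> &perObject) {
  bool rewritten{false};
  for (bool r : perObject) {
    rewritten = r;
  }
  return rewritten;
}

// New pattern: `|=` remembers that any object was rewritten.
bool anyRewrittenOr(const std::vector<bool> &perObject) {
  bool rewritten{false};
  for (bool r : perObject) {
    rewritten |= r;
  }
  return rewritten;
}
```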

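In `visitDesignator`, the assignment of the array element to its temporary is added with `block.insert(it, ...)`. That places it immediately before the element `it` refers to, which the surrounding code suggests is the OpenMP loop construct. The same `it` is also stored in `reassignmentInsertionPoint_`. The toy model below assumes the block is a `std::list`, as flang's `parser::Block` is. It shows why reusing `it` is safe: `std::list::insert` places the new element before the iterator and does not invalidate it. Statements are plain strings here, and `hoistElementIntoTemp` and `makeTempName` are hypothetical helpers. Only the `reduction_temp_` naming scheme comes from the patch.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>

// Builds the temporary's name in the same shape as the patch:
// "reduction_temp_" + array name + "(" + subscript + ")".
std::string makeTempName(const std::string &array,
                         const std::string &subscript) {
  return "reduction_temp_" + array + "(" + subscript + ")";
}

// Toy model (hypothetical): a block is a std::list of statement strings and
// `loop` points at the OpenMP loop construct. Inserts
// "temp = array(subscript)" before the loop and returns the temporary's name.
// std::list::insert does not invalidate `loop`, so it can still be kept and
// used afterwards.
std::string hoistElementIntoTemp(std::list<std::string> &block,
                                 std::list<std::string>::iterator loop,
                                 const std::string &array,
                                 const std::string &subscript) {
  const std::string temp = makeTempName(array, subscript);
  block.insert(loop, temp + " = " + array + "(" + subscript + ")");
  return temp;
}
```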
