[flang-commits] [flang] [flang][OpenMP] Improve reduction of Scalar ArrayElement types (PR #163940)

Jack Styles via flang-commits flang-commits at lists.llvm.org
Mon Oct 20 03:58:43 PDT 2025


================
@@ -492,10 +572,454 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
   FixMisparsedUntaggedNamelistName(x);
 }
 
+/// Bind the compiler-created temporary named `source` and give it the declared
+/// type of the reduction array it stands in for.
+///
+/// The temporary is looked up (not created) in the scope that owns the
+/// original array, so it must already be declared there as an object entity.
+/// The bound symbol has its scope set to that owner scope and is flagged
+/// CompilerCreated.
+/// NOTE(review): `scope` and `context` are currently unused.
+void ReplacementTemp::createTempSymbol(
+    SourceName &source, Scope &scope, SemanticsContext &context) {
+  auto &ownerScope{
+      const_cast<semantics::Scope &>(originalName_.symbol->owner())};
+  replacementTempSymbol_ = ownerScope.FindSymbol(source);
+  assert(replacementTempSymbol_ &&
+      "replacement temporary must be declared before it is bound");
+  replacementTempSymbol_->set_scope(&ownerScope);
+  // The temporary takes the type of the ultimate (host/use resolved) array.
+  const DeclTypeSpec *tempType{originalName_.symbol->GetUltimate().GetType()};
+  assert(tempType && "reduction array must have a declared type");
+  replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+  replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+/// Record which subscript of the reduction array element the temporary
+/// replaces.
+///
+/// Subscripts are scanned left to right, and the first one that is either
+/// kind is recorded:
+///  * an integer literal: its spelling is stored in
+///    originalSubscriptCharBlock_;
+///  * a subscript triplet: isSectionTriplet_ is set.
+/// Any other subscript (for example a variable) is skipped.
+void ReplacementTemp::setOriginalSubscriptInt(
+    std::list<parser::SectionSubscript> &sectionSubscript) {
+  for (parser::SectionSubscript &subscript : sectionSubscript) {
+    if (std::holds_alternative<parser::SubscriptTriplet>(subscript.u)) {
+      isSectionTriplet_ = true;
+      return;
+    }
+    auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+    if (!intExpr) {
+      continue;
+    }
+    parser::Expr &subscriptExpr{intExpr->thing.value()};
+    auto *literal{std::get_if<parser::LiteralConstant>(&subscriptExpr.u)};
+    if (!literal) {
+      continue;
+    }
+    if (auto *intLiteral{
+            std::get_if<parser::IntLiteralConstant>(&literal->u)}) {
+      originalSubscriptCharBlock_ =
+          std::get<parser::CharBlock>(intLiteral->t);
+      return;
+    }
+  }
+}
+
+/// Scan `block` for OpenMP loop constructs and hand each one to
+/// rewriteArrayElementToTemp, which inspects its reduction clauses.
+///
+/// When that call reports a rewrite and the loop's nested construct is a DO
+/// construct, the DO body is walked with this visitor, with block_ pointing at
+/// `block`. The current temporary is reset after every rewritten loop.
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+    parser::Block &block) {
+  // An explicit iterator (not range-for) is needed because
+  // rewriteArrayElementToTemp uses the current position to insert statements
+  // into `block`.
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    auto *execConstruct{std::get_if<parser::ExecutableConstruct>(&it->u)};
+    if (!execConstruct) {
+      continue;
+    }
+    auto *ompConstruct{
+        std::get_if<common::Indirection<parser::OpenMPConstruct>>(
+            &execConstruct->u)};
+    if (!ompConstruct) {
+      continue;
+    }
+    auto *ompLoop{
+        std::get_if<parser::OpenMPLoopConstruct>(&ompConstruct->value().u)};
+    if (!ompLoop) {
+      continue;
+    }
+
+    ReplacementTemp temp;
+    if (!rewriteArrayElementToTemp(it, *ompLoop, block, temp)) {
+      continue;
+    }
+    auto &nestedConstruct{
+        std::get<std::optional<parser::NestedConstruct>>(ompLoop->t)};
+    if (nestedConstruct.has_value()) {
+      if (auto *doConstruct{
+              std::get_if<parser::DoConstruct>(&nestedConstruct.value())}) {
+        block_ = &block;
+        parser::Walk(std::get<parser::Block>(doConstruct->t), *this);
+      }
+    }
+    // Reset the current temp on every path after a successful rewrite, so
+    // future iterations use their own version. NOTE(review): this presumably
+    // also keeps currentTemp_ from outliving the stack-local `temp`; confirm
+    // in rewriteArrayElementToTemp.
+    resetCurrentTemp();
+  }
+}
+
+/// Decide whether `existingDesignator` refers to what the current replacement
+/// temporary stands in for.
+///
+/// Matching rules:
+///  * A bare name matches when it resolves (via GetUltimate) to the same
+///    symbol as the original array.
+///  * An array element matches when its base name matches as above AND its
+///    first integer-literal subscript is spelled identically to the recorded
+///    subscript. Non-literal subscripts and triplets are skipped. An element
+///    with no integer-literal subscript does not match.
+///  * Anything else (substrings, structure components, coindexed objects)
+///    never matches.
+///
+/// For the first fully-matching array element (guarded by
+/// isArrayElementReassigned), reassignTempValueToArrayElement is called.
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+    parser::Designator &existingDesignator) {
+  assert(currentTemp_ != nullptr &&
+      "Value for ReplacementTemp should have been found");
+  const auto &originalName{currentTemp_->getOriginalName()};
+
+  // Unresolved names (null symbols) never match.
+  auto refersToOriginal{[&](const parser::Name &name) {
+    return name.symbol && originalName.symbol &&
+        name.symbol->GetUltimate() == originalName.symbol->GetUltimate();
+  }};
+
+  // Spelling of `subscript` if it is an integer literal, otherwise null.
+  auto intLiteralSpelling{
+      [](parser::SectionSubscript &subscript) -> const parser::CharBlock * {
+        auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+        if (!intExpr) {
+          return nullptr;
+        }
+        auto *literal{
+            std::get_if<parser::LiteralConstant>(&intExpr->thing.value().u)};
+        if (!literal) {
+          return nullptr;
+        }
+        auto *intLiteral{
+            std::get_if<parser::IntLiteralConstant>(&literal->u)};
+        return intLiteral ? &std::get<parser::CharBlock>(intLiteral->t)
+                          : nullptr;
+      }};
+
+  auto *dataRef{std::get_if<parser::DataRef>(&existingDesignator.u)};
+  if (!dataRef) {
+    return false;
+  }
+  if (auto *name{std::get_if<parser::Name>(&dataRef->u)}) {
+    return refersToOriginal(*name);
+  }
+  auto *element{
+      std::get_if<common::Indirection<parser::ArrayElement>>(&dataRef->u)};
+  if (!element) {
+    return false;
+  }
+  parser::ArrayElement &arrayElement{element->value()};
+  auto *baseName{std::get_if<parser::Name>(&arrayElement.base.u)};
+  // A different array must never match, even if its subscript is spelled
+  // the same way.
+  if (!baseName || !refersToOriginal(*baseName)) {
+    return false;
+  }
+
+  // Compare only the first integer-literal subscript. An empty subscript list
+  // is treated as a match.
+  bool matches{arrayElement.subscripts.empty()};
+  for (parser::SectionSubscript &subscript : arrayElement.subscripts) {
+    if (const parser::CharBlock *spelling{intLiteralSpelling(subscript)}) {
+      matches = *spelling == currentTemp_->getOriginalSubscript();
+      break;
+    }
+  }
+
+  // Request the reassignment only for the element that is actually replaced,
+  // never for a same-named element with a different subscript.
+  if (matches && !currentTemp_->isArrayElementReassigned()) {
+    reassignTempValueToArrayElement(arrayElement);
+  }
+  return matches;
+}
+
+/// If `funcRef` is a (mis-parsed) reference to the array element that the
+/// current temporary replaces, overwrite `node` with a designator naming the
+/// temporary, using `source` as its source range.
+///
+/// Only references whose procedure designator is a plain name are
+/// considered. T must be constructible from
+/// common::Indirection<parser::Designator>.
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+    T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  assert(currentTemp_ != nullptr &&
+      "Value for ReplacementTemp should have been found");
+  auto &procedureDesignator{
+      std::get<parser::ProcedureDesignator>(funcRef.v.t)};
+  auto *functionName{std::get_if<parser::Name>(&procedureDesignator.u)};
+  if (!functionName) {
+    return;
+  }
+  const auto &originalName{currentTemp_->getOriginalName()};
+  // Compare ultimate symbols, as isMatchingArrayElement does, so associated
+  // names of the array pass this pre-filter too.
+  if (!functionName->symbol || !originalName.symbol ||
+      !(functionName->symbol->GetUltimate() ==
+          originalName.symbol->GetUltimate())) {
+    return;
+  }
+
+  // NOTE(review): ConvertToArrayElementRef appears to move the actual
+  // arguments out of `funcRef`. If the subscript then fails to match, `node`
+  // keeps a FunctionReference whose arguments were moved from. Confirm.
+  parser::Designator arrayElementDesignator{
+      funcRef.ConvertToArrayElementRef()};
+  if (isMatchingArrayElement(arrayElementDesignator)) {
+    node = T{
+        common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+  }
+}
+
+/// Build a designator that names the current replacement temporary.
+///
+/// The name is bound to the temporary's symbol. The designator carries
+/// `source`, the range of the code it replaces.
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+    parser::CharBlock source) {
+  auto *const symbolOfTemp{currentTemp_->getTempSymbol()};
+  parser::Name nameOfTemp{symbolOfTemp->name()};
+  nameOfTemp.symbol = symbolOfTemp;
+
+  parser::DataRef refToTemp{std::move(nameOfTemp)};
+  parser::Designator result{std::move(refToTemp)};
+  result.source = source;
+  return result;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+    parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+    parser::Block &block, ReplacementTemp &temp) {
+  parser::OmpBeginLoopDirective &ompBeginLoop{
+      std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+  std::list<parser::OmpClause> &clauseList{
+      std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
+  bool rewrittenArrayElement{false};
+
+  for (auto iter{clauseList.begin()}; iter != clauseList.end(); ++iter) {
+    rewrittenArrayElement = std::visit(
+        llvm::makeVisitor(
+            [&](parser::OmpClause::Reduction &clause) -> bool {
+              std::list<parser::OmpObject> &objectList =
+                  std::get<parser::OmpObjectList>(clause.v.t).v;
+
+              bool rewritten{false};
+              for (auto object{objectList.begin()}; object != objectList.end();
+                  ++object) {
----------------
Stylie777 wrote:

I have done this for all but one `for` loop. The remaining one is the loop over the block, which has to keep its explicit iterator because that iterator is used to insert the statement assigning the array element to the temporary.

https://github.com/llvm/llvm-project/pull/163940


More information about the flang-commits mailing list