[flang-commits] [flang] [flang][OpenMP] Improve reduction of Scalar ArrayElement types (PR #163940)
    Jack Styles via flang-commits 
    flang-commits at lists.llvm.org
       
    Mon Oct 20 03:58:43 PDT 2025
    
    
  
================
@@ -492,10 +572,454 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
   FixMisparsedUntaggedNamelistName(x);
 }
 
+// Binds this ReplacementTemp to the temporary named `source` and gives that
+// temporary the declared type of the original reduction variable.
+//
+// The temporary is looked up (FindSymbol), not created, in the scope that owns
+// the original variable's symbol, so it must already be declared there.
+// `scope` and `context` are currently unused.
+void ReplacementTemp::createTempSymbol(
+    SourceName &source, Scope &scope, SemanticsContext &context) {
+  assert(originalName_.symbol != nullptr &&
+      "original reduction name must have a resolved symbol");
+  // Symbol::owner() yields a const Scope; set_scope() below needs it mutable.
+  auto &ownerScope{
+      const_cast<semantics::Scope &>(originalName_.symbol->owner())};
+  replacementTempSymbol_ = ownerScope.FindSymbol(source);
+  assert(replacementTempSymbol_ != nullptr &&
+      "replacement temp must already be declared in the owning scope");
+  // NOTE(review): set_scope() typically records the scope a symbol introduces;
+  // confirm this is intended for an object entity.
+  replacementTempSymbol_->set_scope(&ownerScope);
+  const DeclTypeSpec *tempType{originalName_.symbol->GetUltimate().GetType()};
+  assert(tempType != nullptr &&
+      "original reduction variable must have a declared type");
+  replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+  replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+// Records which subscript identifies the reduced array element.
+//
+// Scans `sectionSubscript` in order and stops at the first subscript that is
+// either
+//   * an integer literal constant: its source text is stored in
+//     originalSubscriptCharBlock_, or
+//   * a subscript triplet: isSectionTriplet_ is set.
+// Subscripts of any other form (variables, general expressions) are skipped;
+// if no subscript qualifies, nothing is recorded.
+void ReplacementTemp::setOriginalSubscriptInt(
+    std::list<parser::SectionSubscript> &sectionSubscript) {
+  for (parser::SectionSubscript &subscript : sectionSubscript) {
+    if (std::holds_alternative<parser::SubscriptTriplet>(subscript.u)) {
+      isSectionTriplet_ = true;
+      return;
+    }
+    auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+    if (!intExpr) {
+      continue;
+    }
+    parser::Expr &expr{intExpr->thing.value()};
+    auto *literalConstant{std::get_if<parser::LiteralConstant>(&expr.u)};
+    if (!literalConstant) {
+      continue;
+    }
+    if (auto *intLiteralConstant{
+            std::get_if<parser::IntLiteralConstant>(&literalConstant->u)}) {
+      originalSubscriptCharBlock_ =
+          std::get<parser::CharBlock>(intLiteralConstant->t);
+      return;
+    }
+  }
+}
+
+// Scans the statements of `block` for OpenMP loop constructs. For each one
+// that rewriteArrayElementToTemp() accepts, walks the body of the associated
+// DO construct (if there is one) with this visitor, then resets the current
+// replacement temp.
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+    parser::Block &block) {
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    auto *execConstruct{std::get_if<parser::ExecutableConstruct>(&it->u)};
+    if (!execConstruct) {
+      continue;
+    }
+    auto *ompConstruct{
+        std::get_if<common::Indirection<parser::OpenMPConstruct>>(
+            &execConstruct->u)};
+    if (!ompConstruct) {
+      continue;
+    }
+    auto *ompLoop{
+        std::get_if<parser::OpenMPLoopConstruct>(&ompConstruct->value().u)};
+    if (!ompLoop) {
+      continue;
+    }
+
+    ReplacementTemp temp;
+    if (!rewriteArrayElementToTemp(it, *ompLoop, block, temp)) {
+      continue;
+    }
+    auto &nestedConstruct{
+        std::get<std::optional<parser::NestedConstruct>>(ompLoop->t)};
+    if (nestedConstruct.has_value()) {
+      if (auto *doConstruct{
+              std::get_if<parser::DoConstruct>(&nestedConstruct.value())}) {
+        block_ = &block;
+        parser::Block &doBlock{std::get<parser::Block>(doConstruct->t)};
+        parser::Walk(doBlock, *this);
+      }
+    }
+    // Reset the current temp so future iterations use their own version. This
+    // is done on every path after a successful rewrite, not only when a DO
+    // body was walked: `temp` is local to this iteration, and (presumably)
+    // currentTemp_ refers to it, so it must not outlive this iteration.
+    resetCurrentTemp();
+  }
+}
+
+// Returns true if `existingDesignator` refers to the reduction variable
+// tracked by currentTemp_:
+//   * a bare name matches when its ultimate symbol is the original one;
+//   * an array element matches only when its base name's ultimate symbol is
+//     the original one AND its first integer-literal subscript has the same
+//     source text as the recorded original subscript. On such a match, if the
+//     temp has not been written back yet, reassignTempValueToArrayElement()
+//     is called with this element.
+// Any other designator form (structure components, substrings, elements of a
+// different array, elements without an integer-literal subscript) does not
+// match.
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+    parser::Designator &existingDesignator) {
+  assert(currentTemp_ != nullptr &&
+      "Value for ReplacementTemp should have been found");
+  const Symbol &originalSymbol{
+      currentTemp_->getOriginalName().symbol->GetUltimate()};
+  // Symbols are compared by identity after resolving use/host association.
+  auto refersToOriginal{[&](const parser::Name &name) {
+    return name.symbol != nullptr &&
+        &name.symbol->GetUltimate() == &originalSymbol;
+  }};
+
+  auto *dataRef{std::get_if<parser::DataRef>(&existingDesignator.u)};
+  if (!dataRef) {
+    return false;
+  }
+  if (auto *name{std::get_if<parser::Name>(&dataRef->u)}) {
+    return refersToOriginal(*name);
+  }
+  auto *arrayElementRef{
+      std::get_if<common::Indirection<parser::ArrayElement>>(&dataRef->u)};
+  if (!arrayElementRef) {
+    return false;
+  }
+  parser::ArrayElement &arrayElement{arrayElementRef->value()};
+  auto *baseName{std::get_if<parser::Name>(&arrayElement.base.u)};
+  if (!baseName || !refersToOriginal(*baseName)) {
+    // An element of another array (e.g. b(1) while reducing a(1)) must never
+    // match, even when its subscript text is identical.
+    return false;
+  }
+
+  // The first integer-literal subscript decides; non-literal subscripts are
+  // skipped, and an element with no integer-literal subscript does not match.
+  bool subscriptMatches{false};
+  for (parser::SectionSubscript &subscript : arrayElement.subscripts) {
+    auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+    if (!intExpr) {
+      continue;
+    }
+    auto *literalConstant{
+        std::get_if<parser::LiteralConstant>(&intExpr->thing.value().u)};
+    if (!literalConstant) {
+      continue;
+    }
+    if (auto *intLiteralConstant{
+            std::get_if<parser::IntLiteralConstant>(&literalConstant->u)}) {
+      subscriptMatches = std::get<parser::CharBlock>(intLiteralConstant->t) ==
+          currentTemp_->getOriginalSubscript();
+      break;
+    }
+  }
+  if (!subscriptMatches) {
+    return false;
+  }
+
+  // Only the reduced element itself triggers the write-back of the temp.
+  if (!currentTemp_->isArrayElementReassigned()) {
+    reassignTempValueToArrayElement(arrayElement);
+  }
+  return true;
+}
+
+// Handles a reference `name(args)` that was parsed as a function reference
+// but whose name resolves (by ultimate symbol) to the reduction variable
+// tracked by currentTemp_. Such a reference is really an array element. It is
+// converted with ConvertToArrayElementRef(), and `node` is replaced with
+//   * the replacement temp designator (source set to `source`) when the
+//     element matches the reduced element, or
+//   * the converted array element designator otherwise.
+// `node` must be replaced in both cases: ConvertToArrayElementRef() is
+// non-const and, as flang implements it, moves the actual arguments out of
+// `funcRef`, so the original function reference is not reusable afterwards.
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+    T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  assert(currentTemp_ != nullptr &&
+      "Value for ReplacementTemp should have been found");
+  auto &procedureDesignator{
+      std::get<parser::ProcedureDesignator>(funcRef.v.t)};
+  auto *functionName{std::get_if<parser::Name>(&procedureDesignator.u)};
+  const parser::Name &originalName{currentTemp_->getOriginalName()};
+  // Compare ultimate symbols, as isMatchingArrayElement() does, so that use-
+  // or host-associated names of the reduction variable are recognized too.
+  if (!functionName || !functionName->symbol || !originalName.symbol ||
+      &functionName->symbol->GetUltimate() !=
+          &originalName.symbol->GetUltimate()) {
+    return;
+  }
+
+  parser::Designator arrayElementDesignator{
+      funcRef.ConvertToArrayElementRef()};
+  if (isMatchingArrayElement(arrayElementDesignator)) {
+    node = T{
+        common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+    return;
+  }
+  // Not the reduced element: keep the (now converted) array element reference.
+  if (arrayElementDesignator.source.empty()) {
+    arrayElementDesignator.source = source;
+  }
+  node = T{common::Indirection<parser::Designator>{
+      std::move(arrayElementDesignator)}};
+}
+
+// Returns a designator naming currentTemp_'s replacement temporary. The
+// designator's Name carries the temp's symbol directly, and the designator's
+// source is set to `source` (the text of the reference being replaced).
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+    parser::CharBlock source) {
+  auto *tempSymbol{currentTemp_->getTempSymbol()};
+  parser::Name tempName{tempSymbol->name()};
+  tempName.symbol = tempSymbol;
+  parser::DataRef tempDataRef{std::move(tempName)};
+  parser::Designator result{std::move(tempDataRef)};
+  result.source = source;
+  return result;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+    parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+    parser::Block &block, ReplacementTemp &temp) {
+  parser::OmpBeginLoopDirective &ompBeginLoop{
+      std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+  std::list<parser::OmpClause> &clauseList{
+      std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
+  bool rewrittenArrayElement{false};
+
+  for (auto iter{clauseList.begin()}; iter != clauseList.end(); ++iter) {
+    rewrittenArrayElement = std::visit(
+        llvm::makeVisitor(
+            [&](parser::OmpClause::Reduction &clause) -> bool {
+              std::list<parser::OmpObject> &objectList =
+                  std::get<parser::OmpObjectList>(clause.v.t).v;
+
+              bool rewritten{false};
+              for (auto object{objectList.begin()}; object != objectList.end();
+                  ++object) {
+                rewritten = std::visit(
+                    llvm::makeVisitor(
+                        [&](parser::Designator &designator) -> bool {
+                          if (!identifyArrayElementReduced(designator, temp)) {
+                            return false;
+                          }
+                          if (temp.isSectionTriplet()) {
+                            return false;
+                          }
+
+                          reassignmentInsertionPoint_ =
+                              it != block.end() ? it : block.end();
----------------
Stylie777 wrote:
Done
https://github.com/llvm/llvm-project/pull/163940
    
    
More information about the flang-commits
mailing list