[flang-commits] [flang] [flang][OpenMP] Improve reduction of Scalar ArrayElement types (PR #163940)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Mon Oct 20 06:09:03 PDT 2025


================
@@ -492,10 +572,468 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
   FixMisparsedUntaggedNamelistName(x);
 }
 
+/// Binds this replacement temp to the symbol named \p source.
+///
+/// The symbol is looked up by name (Scope::FindSymbol) starting from the
+/// scope that owns the original array's symbol. It is then given that scope,
+/// the declared type of the original array's ultimate symbol, and the
+/// CompilerCreated flag.
+///
+/// \p scope and \p context are currently unused.
+void ReplacementTemp::createTempSymbol(
+    SourceName &source, Scope &scope, SemanticsContext &context) {
+  // owner() yields a const Scope; the lookup result is mutated below, so
+  // strip const once here instead of at every use.
+  Scope &ownerScope{const_cast<Scope &>(originalName_.symbol->owner())};
+  replacementTempSymbol_ = ownerScope.FindSymbol(source);
+  // FindSymbol only looks up an existing symbol; fail loudly instead of
+  // dereferencing a null pointer if the temp was never declared.
+  assert(replacementTempSymbol_ != nullptr &&
+      "replacement temp symbol must exist before createTempSymbol");
+  // NOTE(review): confirm that set_scope (rather than relying on the
+  // symbol's owner) is what is intended for an object entity.
+  replacementTempSymbol_->set_scope(&ownerScope);
+  const DeclTypeSpec *tempType{originalName_.symbol->GetUltimate().GetType()};
+  assert(tempType != nullptr &&
+      "original reduction array must have a declared type");
+  replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+  replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+/// Records which subscript identifies the original array element.
+///
+/// Subscripts are scanned left to right, and the scan stops at the first one
+/// that is either:
+///  - a subscript triplet: isSectionTriplet_ is set; or
+///  - an integer literal: its source text is saved in
+///    originalSubscriptCharBlock_.
+/// Every other subscript expression is skipped.
+void ReplacementTemp::setOriginalSubscriptInt(
+    std::list<parser::SectionSubscript> &sectionSubscript) {
+  for (parser::SectionSubscript &sub : sectionSubscript) {
+    if (std::holds_alternative<parser::SubscriptTriplet>(sub.u)) {
+      isSectionTriplet_ = true;
+      return;
+    }
+    auto *intExpr{std::get_if<parser::IntExpr>(&sub.u)};
+    if (!intExpr) {
+      continue;
+    }
+    auto *literal{
+        std::get_if<parser::LiteralConstant>(&intExpr->thing.value().u)};
+    if (!literal) {
+      continue;
+    }
+    if (auto *intLiteral{
+            std::get_if<parser::IntLiteralConstant>(&literal->u)}) {
+      originalSubscriptCharBlock_ = std::get<parser::CharBlock>(intLiteral->t);
+      return;
+    }
+  }
+}
+
+/// Rewrites reductions over scalar array elements in the OpenMP loop
+/// constructs that appear directly in \p block.
+///
+/// For each OpenMPLoopConstruct statement of \p block,
+/// rewriteArrayElementToTemp decides whether a replacement temp applies. If
+/// it does and the construct wraps a DO loop, the loop body is walked with
+/// this rewriter so that matching references can be replaced by the temp.
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+    parser::Block &block) {
+  // Handles one OpenMP loop construct located at position `it` of `block`.
+  auto rewriteLoop = [&](parser::OpenMPLoopConstruct &ompLoop,
+                         parser::Block::iterator &it) {
+    ReplacementTemp temp;
+    if (!rewriteArrayElementToTemp(it, ompLoop, block, temp)) {
+      return;
+    }
+    auto &nestedConstruct{
+        std::get<std::optional<parser::NestedConstruct>>(ompLoop.t)};
+    if (nestedConstruct.has_value()) {
+      if (auto *doConstruct{
+              std::get_if<parser::DoConstruct>(&nestedConstruct.value())}) {
+        block_ = &block;
+        parser::Walk(std::get<parser::Block>(doConstruct->t), *this);
+      }
+    }
+    // `temp` is local to this lambda and is destroyed on return. Reset the
+    // current temp on every path after a successful rewrite (not only after
+    // walking a DO loop) so a temp from this construct cannot leak into
+    // later ones.
+    // NOTE(review): assumes rewriteArrayElementToTemp made `temp` current
+    // and that resetCurrentTemp() only clears that state — confirm.
+    resetCurrentTemp();
+  };
+
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    auto *execConstruct{std::get_if<parser::ExecutableConstruct>(&it->u)};
+    if (!execConstruct) {
+      continue;
+    }
+    auto *ompConstruct{
+        std::get_if<common::Indirection<parser::OpenMPConstruct>>(
+            &execConstruct->u)};
+    if (!ompConstruct) {
+      continue;
+    }
+    if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(
+            &ompConstruct->value().u)}) {
+      rewriteLoop(*ompLoop, it);
+    }
+  }
+}
+
+/// Returns true if \p existingDesignator refers to the array element (or,
+/// for a bare name, the whole array) that the current replacement temp
+/// stands for.
+///
+/// The base name must resolve (via GetUltimate) to the same symbol as the
+/// temp's original array. For an array element, the first subscript that is
+/// an integer literal must also have the same source text as the original
+/// subscript. An element with no integer-literal subscript does not match.
+///
+/// Side effect: when the base name of an array element matches and the temp
+/// has not been reassigned yet, reassignTempValueToArrayElement is called
+/// with that element.
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+    parser::Designator &existingDesignator) {
+  // True if `name` resolves to the same ultimate symbol as the array that
+  // the current temp replaces.
+  auto isOriginalArray = [&](parser::Name &name) -> bool {
+    assert(currentTemp_ != nullptr &&
+        "Value for ReplacementTemp should have "
+        "been found");
+    return name.symbol->GetUltimate() ==
+        currentTemp_->getOriginalName().symbol->GetUltimate();
+  };
+
+  auto *dataRef{std::get_if<parser::DataRef>(&existingDesignator.u)};
+  if (!dataRef) {
+    return false;
+  }
+  // A bare name refers to the whole array; there are no subscripts to check.
+  if (auto *name{std::get_if<parser::Name>(&dataRef->u)}) {
+    return isOriginalArray(*name);
+  }
+  auto *elementRef{
+      std::get_if<common::Indirection<parser::ArrayElement>>(&dataRef->u)};
+  if (!elementRef) {
+    return false;
+  }
+  parser::ArrayElement &arrayElement{elementRef->value()};
+  // Compare subscripts only once the base array is known to match. An
+  // element of a different array (e.g. b(1) for a temp of a(1)) must not
+  // match merely because its literal subscript has the same text.
+  auto *baseName{std::get_if<parser::Name>(&arrayElement.base.u)};
+  if (!baseName || !isOriginalArray(*baseName)) {
+    return false;
+  }
+  if (!currentTemp_->isArrayElementReassigned()) {
+    // NOTE(review): this runs before the subscripts are compared, so it also
+    // fires for elements of the original array with a different subscript;
+    // confirm that is intended.
+    reassignTempValueToArrayElement(arrayElement);
+  }
+  // Only the first integer-literal subscript is compared. Any other kind of
+  // subscript (variable, expression, triplet) before it is skipped.
+  for (parser::SectionSubscript &subscript : arrayElement.subscripts) {
+    auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+    if (!intExpr) {
+      continue;
+    }
+    auto *literal{
+        std::get_if<parser::LiteralConstant>(&intExpr->thing.value().u)};
+    if (!literal) {
+      continue;
+    }
+    if (auto *intLiteral{
+            std::get_if<parser::IntLiteralConstant>(&literal->u)}) {
+      return std::get<parser::CharBlock>(intLiteral->t) ==
+          currentTemp_->getOriginalSubscript();
+    }
+  }
+  return false;
+}
+
+/// Replaces \p node with a designator for the current replacement temp when
+/// \p funcRef is a (misparsed) reference to the temp's original array and
+/// the array element it converts to matches the temp.
+///
+/// \p source becomes the source range of the replacement designator.
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+    T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  auto &procDesignator{std::get<parser::ProcedureDesignator>(funcRef.v.t)};
+  // Only a plain name can denote the original array.
+  auto *procName{std::get_if<parser::Name>(&procDesignator.u)};
+  if (!procName ||
+      currentTemp_->getOriginalName().symbol != procName->symbol) {
+    return;
+  }
+  // NOTE(review): confirm whether ConvertToArrayElementRef() leaves funcRef
+  // intact; if it moves the arguments out, funcRef stays modified when the
+  // element below does not match.
+  std::optional<parser::Designator> elementDesignator{
+      funcRef.ConvertToArrayElementRef()};
+  if (elementDesignator && isMatchingArrayElement(*elementDesignator)) {
+    node = T{
+        common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+  }
+}
+
+/// Builds a designator that names the current replacement temp's symbol,
+/// using \p source as its source range.
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+    parser::CharBlock source) {
+  auto *tempSymbol{currentTemp_->getTempSymbol()};
+  parser::Name name{tempSymbol->name()};
+  name.symbol = tempSymbol;
+  parser::Designator designator{parser::DataRef{std::move(name)}};
+  designator.source = source;
+  return designator;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+    parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+    parser::Block &block, ReplacementTemp &temp) {
+  parser::OmpBeginLoopDirective &ompBeginLoop{
+      std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+  std::list<parser::OmpClause> &clauseList{
+      std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
----------------
kparzysz wrote:

In case you didn't know, OpenMPLoopConstruct has a member function "BeginDir" that returns a reference to the contained "OmpBeginLoopDirective". OmpBeginLoopDirective inherits from OmpDirectiveSpecification, whose member function Clauses() returns the clause list (an OmpClauseList).
You could do
```
auto &clauseList{const_cast<std::list<parser::OmpClause> &>(ompLoop.BeginDir().Clauses().v)};
```
or
```
auto &clauseList{const_cast<parser::OmpClauseList &>(ompLoop.BeginDir().Clauses())};
```
and use clauseList.v.

https://github.com/llvm/llvm-project/pull/163940


More information about the flang-commits mailing list