[flang-commits] [flang] [flang][OpenMP] Improve reduction of Scalar ArrayElement types (PR #163940)
Jack Styles via flang-commits
flang-commits at lists.llvm.org
Fri Oct 31 04:24:32 PDT 2025
================
@@ -492,10 +572,468 @@ void RewriteMutator::Post(parser::WriteStmt &x) {
FixMisparsedUntaggedNamelistName(x);
}
+// Binds the compiler-created temporary that stands in for the original
+// variable: looks `source` up via the scope that owns the original symbol,
+// records that scope on the temp, gives the temp the original (ultimate)
+// symbol's declared type, and flags it as CompilerCreated.
+// NOTE(review): the `scope` and `context` parameters are currently unused.
+void ReplacementTemp::createTempSymbol(
+    SourceName &source, Scope &scope, SemanticsContext &context) {
+  Scope &ownerScope{const_cast<Scope &>(originalName_.symbol->owner())};
+  replacementTempSymbol_ = ownerScope.FindSymbol(source);
+  // Both the lookup result and the original type are dereferenced below;
+  // fail loudly in asserts-enabled builds instead of crashing on nullptr.
+  assert(replacementTempSymbol_ != nullptr &&
+      "Replacement temp symbol not found in the owning scope");
+  replacementTempSymbol_->set_scope(&ownerScope);
+  DeclTypeSpec *tempType = originalName_.symbol->GetUltimate().GetType();
+  assert(tempType != nullptr && "Original reduction variable has no type");
+  replacementTempSymbol_->get<ObjectEntityDetails>().set_type(*tempType);
+  replacementTempSymbol_->flags().set(Symbol::Flag::CompilerCreated);
+}
+
+// Scans the subscripts of the original array element and records the first
+// one that is either
+//  - a plain integer literal: its source text is saved in
+//    originalSubscriptCharBlock_, or
+//  - a section triplet: only isSectionTriplet_ is set.
+// Subscripts that are neither (e.g. variables or expressions) are skipped.
+void ReplacementTemp::setOriginalSubscriptInt(
+    std::list<parser::SectionSubscript> &sectionSubscript) {
+  for (parser::SectionSubscript &subscript : sectionSubscript) {
+    if (std::holds_alternative<parser::SubscriptTriplet>(subscript.u)) {
+      isSectionTriplet_ = true;
+      return;
+    }
+    auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+    if (!intExpr) {
+      continue;
+    }
+    auto *literal{
+        std::get_if<parser::LiteralConstant>(&intExpr->thing.value().u)};
+    if (!literal) {
+      continue;
+    }
+    if (auto *intLiteral{
+            std::get_if<parser::IntLiteralConstant>(&literal->u)}) {
+      originalSubscriptCharBlock_ = std::get<parser::CharBlock>(intLiteral->t);
+      return;
+    }
+  }
+}
+
+// Scans the top-level statements of `block` for OpenMP loop constructs. For
+// each one where rewriteArrayElementToTemp succeeds and the associated nested
+// construct is a DO loop, walks the DO loop body with this mutator (so
+// references can be rewritten to the temporary) and then resets the current
+// temporary.
+void RewriteOmpReductionArrayElements::rewriteReductionArrayElementToTemp(
+    parser::Block &block) {
+  for (auto it{block.begin()}; it != block.end(); ++it) {
+    auto *execConstruct{std::get_if<parser::ExecutableConstruct>(&it->u)};
+    if (!execConstruct) {
+      continue;
+    }
+    auto *ompConstruct{
+        std::get_if<common::Indirection<parser::OpenMPConstruct>>(
+            &execConstruct->u)};
+    if (!ompConstruct) {
+      continue;
+    }
+    auto *ompLoop{
+        std::get_if<parser::OpenMPLoopConstruct>(&ompConstruct->value().u)};
+    if (!ompLoop) {
+      continue;
+    }
+    // `it` is passed by reference and may be updated by the rewrite.
+    ReplacementTemp temp;
+    if (!rewriteArrayElementToTemp(it, *ompLoop, block, temp)) {
+      continue;
+    }
+    // NOTE(review): on the two early exits below resetCurrentTemp() is not
+    // called even though the rewrite succeeded; if currentTemp_ refers to
+    // `temp` it would dangle once this iteration ends -- confirm.
+    auto &nestedConstruct{
+        std::get<std::optional<parser::NestedConstruct>>(ompLoop->t)};
+    if (!nestedConstruct) {
+      continue;
+    }
+    if (auto *doConstruct{
+            std::get_if<parser::DoConstruct>(&*nestedConstruct)}) {
+      block_ = &block;
+      parser::Walk(std::get<parser::Block>(doConstruct->t), *this);
+      // Reset the current temp value so later loop constructs use their own.
+      resetCurrentTemp();
+    }
+  }
+}
+
+// Returns true when `existingDesignator` refers to what the current
+// replacement temporary stands for:
+//  - a bare name matches when it resolves to the same ultimate symbol as the
+//    original name;
+//  - an array element matches only when its base name matches AND its first
+//    integer-literal subscript has the same source text as the recorded
+//    original subscript (non-literal subscripts are skipped).
+// Side effect: for an array element whose base name matches, calls
+// reassignTempValueToArrayElement if the temp reports it has not yet been
+// reassigned.
+bool RewriteOmpReductionArrayElements::isMatchingArrayElement(
+    parser::Designator &existingDesignator) {
+  assert(currentTemp_ != nullptr &&
+      "Value for ReplacementTemp should have "
+      "been found");
+  // True when `name` resolves to the same entity as the original name.
+  auto refersToOriginal{[&](parser::Name &name) {
+    return name.symbol &&
+        &name.symbol->GetUltimate() ==
+        &currentTemp_->getOriginalName().symbol->GetUltimate();
+  }};
+  // Returns the integer literal that forms `subscript`, or nullptr.
+  auto getIntLiteral{[](parser::SectionSubscript &subscript)
+                         -> parser::IntLiteralConstant * {
+    auto *intExpr{std::get_if<parser::IntExpr>(&subscript.u)};
+    if (!intExpr) {
+      return nullptr;
+    }
+    auto *literal{
+        std::get_if<parser::LiteralConstant>(&intExpr->thing.value().u)};
+    return literal ? std::get_if<parser::IntLiteralConstant>(&literal->u)
+                   : nullptr;
+  }};
+
+  bool baseMatches{false};
+  std::list<parser::SectionSubscript> *subscripts{nullptr};
+  auto visitArrayElement{[&](parser::ArrayElement &arrayElement) {
+    subscripts = &arrayElement.subscripts;
+    if (auto *name{std::get_if<parser::Name>(&arrayElement.base.u)};
+        name && refersToOriginal(*name)) {
+      baseMatches = true;
+      if (!currentTemp_->isArrayElementReassigned()) {
+        reassignTempValueToArrayElement(arrayElement);
+      }
+    }
+  }};
+  if (auto *dataRef{std::get_if<parser::DataRef>(&existingDesignator.u)}) {
+    if (auto *arrayElement{
+            std::get_if<common::Indirection<parser::ArrayElement>>(
+                &dataRef->u)}) {
+      visitArrayElement(arrayElement->value());
+    } else if (auto *name{std::get_if<parser::Name>(&dataRef->u)}) {
+      baseMatches = refersToOriginal(*name);
+    }
+  }
+
+  // A bare name is decided by the base alone. An element of a different
+  // array must never match, whatever its subscripts are.
+  if (!subscripts || !baseMatches) {
+    return baseMatches;
+  }
+  // Compare the first integer-literal subscript with the recorded original.
+  // If no subscript is a literal there is no match; an empty subscript list
+  // keeps the base-name result.
+  bool matchesArrayElement{true};
+  for (parser::SectionSubscript &subscript : *subscripts) {
+    parser::IntLiteralConstant *intLiteral{getIntLiteral(subscript)};
+    matchesArrayElement = intLiteral != nullptr &&
+        std::get<parser::CharBlock>(intLiteral->t) ==
+            currentTemp_->getOriginalSubscript();
+    if (intLiteral) {
+      break;
+    }
+  }
+  return matchesArrayElement;
+}
+
+// If `funcRef` is a call through a plain name that resolves to the same
+// ultimate symbol as the current temporary's original name (the comparison
+// isMatchingArrayElement uses), reinterprets it as an array element
+// reference. When that element matches, replaces `node` with a designator
+// for the temporary whose source is `source`.
+// NOTE(review): ConvertToArrayElementRef() may move the call's arguments
+// out of `funcRef`; confirm this is safe when the element does not match.
+template <typename T>
+void RewriteOmpReductionArrayElements::processFunctionReference(
+    T &node, parser::CharBlock source, parser::FunctionReference &funcRef) {
+  auto &procedureDesignator{
+      std::get<parser::ProcedureDesignator>(funcRef.v.t)};
+  auto *functionName{std::get_if<parser::Name>(&procedureDesignator.u)};
+  const Symbol *originalSymbol{currentTemp_->getOriginalName().symbol};
+  if (!functionName || !functionName->symbol || !originalSymbol) {
+    return;
+  }
+  if (&functionName->symbol->GetUltimate() != &originalSymbol->GetUltimate()) {
+    return;
+  }
+  parser::Designator arrayElementRef{funcRef.ConvertToArrayElementRef()};
+  if (isMatchingArrayElement(arrayElementRef)) {
+    node = T{
+        common::Indirection<parser::Designator>{makeTempDesignator(source)}};
+  }
+}
+
+// Produces a `DataRef` designator that names the current replacement
+// temporary. The name carries the temp's resolved symbol, and the designator's
+// source location is set to `source`.
+parser::Designator RewriteOmpReductionArrayElements::makeTempDesignator(
+    parser::CharBlock source) {
+  auto *tempSymbol{currentTemp_->getTempSymbol()};
+  parser::Name tempName{tempSymbol->name()};
+  tempName.symbol = tempSymbol;
+  parser::Designator designator{parser::DataRef{std::move(tempName)}};
+  designator.source = source;
+  return designator;
+}
+
+bool RewriteOmpReductionArrayElements::rewriteArrayElementToTemp(
+ parser::Block::iterator &it, parser::OpenMPLoopConstruct &ompLoop,
+ parser::Block &block, ReplacementTemp &temp) {
+ parser::OmpBeginLoopDirective &ompBeginLoop{
+ std::get<parser::OmpBeginLoopDirective>(ompLoop.t)};
+ std::list<parser::OmpClause> &clauseList{
+ std::get<std::optional<parser::OmpClauseList>>(ompBeginLoop.t)->v};
----------------
Stylie777 wrote:
Updated.
https://github.com/llvm/llvm-project/pull/163940
More information about the flang-commits
mailing list