[llvm-branch-commits] [flang] [flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor (PR #81623)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Feb 13 08:32:37 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

Convert repeatable clauses (except Map) in ClauseProcessor.

Rename `findRepeatableClause` to `findRepeatableClause2`, and add a new `findRepeatableClause` that operates on the new `omp::Clause` objects.

Leave `Map` unchanged, because converting it will require more extensive changes.

---

The patch is 51.45 KiB and is truncated to 20.00 KiB below; the full version is at https://github.com/llvm/llvm-project/pull/81623.diff


2 Files Affected:

- (modified) flang/include/flang/Evaluate/tools.h (+23) 
- (modified) flang/lib/Lower/OpenMP.cpp (+305-327) 


``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index d257da1a709642..e9999974944e88 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
   }
 }
 
+struct ExtractSubstringHelper {
+  template <typename T> static std::optional<Substring> visit(T &&) {
+    return std::nullopt;
+  }
+
+  static std::optional<Substring> visit(const Substring &e) { return e; }
+
+  template <typename T>
+  static std::optional<Substring> visit(const Designator<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+
+  template <typename T>
+  static std::optional<Substring> visit(const Expr<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+};
+
+template <typename A>
+std::optional<Substring> ExtractSubstring(const A &x) {
+  return ExtractSubstringHelper::visit(x);
+}
+
 // If an expression is simply a whole symbol data designator,
 // extract and return that symbol, else null.
 template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index d7a93db15a4bb8..4b21ab934c9393 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -72,9 +72,9 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
   return sym;
 }
 
-static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
-                          Fortran::lower::AbstractConverter &converter,
-                          llvm::SmallVectorImpl<mlir::Value> &operands) {
+static void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+                           Fortran::lower::AbstractConverter &converter,
+                           llvm::SmallVectorImpl<mlir::Value> &operands) {
   auto addOperands = [&](Fortran::lower::SymbolRef sym) {
     const mlir::Value variable = converter.getSymbolAddress(sym);
     if (variable) {
@@ -93,27 +93,6 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
   }
 }
 
-static void gatherFuncAndVarSyms(
-    const Fortran::parser::OmpObjectList &objList,
-    mlir::omp::DeclareTargetCaptureClause clause,
-    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
-  for (const Fortran::parser::OmpObject &ompObject : objList.v) {
-    Fortran::common::visit(
-        Fortran::common::visitors{
-            [&](const Fortran::parser::Designator &designator) {
-              if (const Fortran::parser::Name *name =
-                      Fortran::semantics::getDesignatorNameIfDataRef(
-                          designator)) {
-                symbolAndClause.emplace_back(clause, *name->symbol);
-              }
-            },
-            [&](const Fortran::parser::Name &name) {
-              symbolAndClause.emplace_back(clause, *name.symbol);
-            }},
-        ompObject.u);
-  }
-}
-
 static Fortran::lower::pft::Evaluation *
 getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
   // Return the Evaluation of the innermost collapsed loop, or the current one
@@ -1257,6 +1236,32 @@ List<Clause> makeList(const parser::OmpClauseList &clauses,
 }
 } // namespace omp
 
+static void genObjectList(const omp::ObjectList &objects,
+                          Fortran::lower::AbstractConverter &converter,
+                          llvm::SmallVectorImpl<mlir::Value> &operands) {
+  for (const omp::Object &object : objects) {
+    const Fortran::semantics::Symbol *sym = object.sym;
+    assert(sym && "Expected Symbol");
+    if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
+      operands.push_back(variable);
+    } else {
+      if (const auto *details =
+              sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+        operands.push_back(converter.getSymbolAddress(details->symbol()));
+        converter.copySymbolBinding(details->symbol(), *sym);
+      }
+    }
+  }
+}
+
+static void gatherFuncAndVarSyms(
+    const omp::ObjectList &objects,
+    mlir::omp::DeclareTargetCaptureClause clause,
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+  for (const omp::Object &object : objects)
+    symbolAndClause.emplace_back(clause, *object.sym);
+}
+
 //===----------------------------------------------------------------------===//
 // DataSharingProcessor
 //===----------------------------------------------------------------------===//
@@ -1718,9 +1723,8 @@ class ClauseProcessor {
                      llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
   bool
   processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
-  bool
-  processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
-            mlir::Value &result) const;
+  bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
+                 mlir::Value &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
@@ -1815,6 +1819,26 @@ class ClauseProcessor {
   /// if at least one instance was found.
   template <typename T>
   bool findRepeatableClause(
+      std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+          callbackFn) const {
+    bool found = false;
+    ClauseIterator nextIt, endIt = clauses.end();
+    for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+      nextIt = findClause<T>(it, endIt);
+
+      if (nextIt != endIt) {
+        callbackFn(std::get<T>(nextIt->u), nextIt->source);
+        found = true;
+        ++nextIt;
+      }
+    }
+    return found;
+  }
+
+  /// Call `callbackFn` for each occurrence of the given clause. Return `true`
+  /// if at least one instance was found.
+  template <typename T>
+  bool findRepeatableClause2(
       std::function<void(const T *, const Fortran::parser::CharBlock &source)>
           callbackFn) const {
     bool found = false;
@@ -1880,9 +1904,9 @@ class ReductionProcessor {
     IEOR
   };
   static ReductionIdentifier
-  getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+  getReductionType(const omp::clause::ProcedureDesignator &pd) {
     auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
-                       getRealName(pd).ToString())
+                       getRealName(pd.v.sym).ToString())
                        .Case("max", ReductionIdentifier::MAX)
                        .Case("min", ReductionIdentifier::MIN)
                        .Case("iand", ReductionIdentifier::IAND)
@@ -1894,35 +1918,33 @@ class ReductionProcessor {
   }
 
   static ReductionIdentifier getReductionType(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+      omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
     switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case omp::clause::DefinedOperator::IntrinsicOperator::Add:
       return ReductionIdentifier::ADD;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+    case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
       return ReductionIdentifier::SUBTRACT;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
       return ReductionIdentifier::MULTIPLY;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case omp::clause::DefinedOperator::IntrinsicOperator::AND:
       return ReductionIdentifier::AND;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
       return ReductionIdentifier::EQV;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case omp::clause::DefinedOperator::IntrinsicOperator::OR:
       return ReductionIdentifier::OR;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+    case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
       return ReductionIdentifier::NEQV;
     default:
       llvm_unreachable("unexpected intrinsic operator in reduction");
     }
   }
 
-  static bool supportedIntrinsicProcReduction(
-      const Fortran::parser::ProcedureDesignator &pd) {
-    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
-    assert(name && "Invalid Reduction Intrinsic.");
-    if (!name->symbol->GetUltimate().attrs().test(
-            Fortran::semantics::Attr::INTRINSIC))
+  static bool
+  supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) {
+    Fortran::semantics::Symbol *sym = pd.v.sym;
+    if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
       return false;
-    auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+    auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
                        .Case("max", true)
                        .Case("min", true)
                        .Case("iand", true)
@@ -1933,15 +1955,13 @@ class ReductionProcessor {
   }
 
   static const Fortran::semantics::SourceName
-  getRealName(const Fortran::parser::Name *name) {
-    return name->symbol->GetUltimate().name();
+  getRealName(const Fortran::semantics::Symbol *symbol) {
+    return symbol->GetUltimate().name();
   }
 
   static const Fortran::semantics::SourceName
-  getRealName(const Fortran::parser::ProcedureDesignator &pd) {
-    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
-    assert(name && "Invalid Reduction Intrinsic.");
-    return getRealName(name);
+  getRealName(const omp::clause::ProcedureDesignator &pd) {
+    return getRealName(pd.v.sym);
   }
 
   static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
@@ -1951,25 +1971,25 @@ class ReductionProcessor {
         .str();
   }
 
-  static std::string getReductionName(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type ty) {
+  static std::string
+  getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
+                   mlir::Type ty) {
     std::string reductionName;
 
     switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case omp::clause::DefinedOperator::IntrinsicOperator::Add:
       reductionName = "add_reduction";
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
       reductionName = "multiply_reduction";
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case omp::clause::DefinedOperator::IntrinsicOperator::AND:
       return "and_reduction";
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
       return "eqv_reduction";
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case omp::clause::DefinedOperator::IntrinsicOperator::OR:
       return "or_reduction";
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+    case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
       return "neqv_reduction";
     default:
       reductionName = "other_reduction";
@@ -2213,7 +2233,7 @@ class ReductionProcessor {
   static void
   addReductionDecl(mlir::Location currentLocation,
                    Fortran::lower::AbstractConverter &converter,
-                   const Fortran::parser::OmpReductionClause &reduction,
+                   const omp::clause::Reduction &reduction,
                    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
                    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
                    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -2221,13 +2241,12 @@ class ReductionProcessor {
     fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
     mlir::omp::ReductionDeclareOp decl;
     const auto &redOperator{
-        std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
-    const auto &objectList{
-        std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+        std::get<omp::clause::ReductionOperator>(reduction.t)};
+    const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
     if (const auto &redDefinedOp =
-            std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+            std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
       const auto &intrinsicOp{
-          std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+          std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
               redDefinedOp->u)};
       ReductionIdentifier redId = getReductionType(intrinsicOp);
       switch (redId) {
@@ -2243,10 +2262,41 @@ class ReductionProcessor {
              "Reduction of some intrinsic operators is not supported");
         break;
       }
-      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-        if (const auto *name{
-                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+      for (const omp::Object &object : objectList) {
+        if (const Fortran::semantics::Symbol *symbol = object.sym) {
+          if (reductionSymbols)
+            reductionSymbols->push_back(symbol);
+          mlir::Value symVal = converter.getSymbolAddress(*symbol);
+          if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+            symVal = declOp.getBase();
+          mlir::Type redType =
+              symVal.getType().cast<fir::ReferenceType>().getEleTy();
+          reductionVars.push_back(symVal);
+          if (redType.isa<fir::LogicalType>())
+            decl = createReductionDecl(
+                firOpBuilder,
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
+                redType, currentLocation);
+          else if (redType.isIntOrIndexOrFloat()) {
+            decl = createReductionDecl(firOpBuilder,
+                                       getReductionName(intrinsicOp, redType),
+                                       redId, redType, currentLocation);
+          } else {
+            TODO(currentLocation, "Reduction of some types is not supported");
+          }
+          reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+              firOpBuilder.getContext(), decl.getSymName()));
+        }
+      }
+    } else if (const auto *reductionIntrinsic =
+                   std::get_if<omp::clause::ProcedureDesignator>(
+                       &redOperator.u)) {
+      if (ReductionProcessor::supportedIntrinsicProcReduction(
+              *reductionIntrinsic)) {
+        ReductionProcessor::ReductionIdentifier redId =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
+        for (const omp::Object &object : objectList) {
+          if (const Fortran::semantics::Symbol *symbol = object.sym) {
             if (reductionSymbols)
               reductionSymbols->push_back(symbol);
             mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -2255,55 +2305,18 @@ class ReductionProcessor {
             mlir::Type redType =
                 symVal.getType().cast<fir::ReferenceType>().getEleTy();
             reductionVars.push_back(symVal);
-            if (redType.isa<fir::LogicalType>())
-              decl = createReductionDecl(
-                  firOpBuilder,
-                  getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                  redId, redType, currentLocation);
-            else if (redType.isIntOrIndexOrFloat()) {
-              decl = createReductionDecl(firOpBuilder,
-                                         getReductionName(intrinsicOp, redType),
-                                         redId, redType, currentLocation);
-            } else {
-              TODO(currentLocation, "Reduction of some types is not supported");
-            }
+            assert(redType.isIntOrIndexOrFloat() &&
+                   "Unsupported reduction type");
+            decl = createReductionDecl(
+                firOpBuilder,
+                getReductionName(getRealName(*reductionIntrinsic).ToString(),
+                                 redType),
+                redId, redType, currentLocation);
             reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
                 firOpBuilder.getContext(), decl.getSymName()));
           }
         }
       }
-    } else if (const auto *reductionIntrinsic =
-                   std::get_if<Fortran::parser::ProcedureDesignator>(
-                       &redOperator.u)) {
-      if (ReductionProcessor::supportedIntrinsicProcReduction(
-              *reductionIntrinsic)) {
-        ReductionProcessor::ReductionIdentifier redId =
-            ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-          if (const auto *name{
-                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-              if (reductionSymbols)
-                reductionSymbols->push_back(symbol);
-              mlir::Value symVal = converter.getSymbolAddress(*symbol);
-              if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
-                symVal = declOp.getBase();
-              mlir::Type redType =
-                  symVal.getType().cast<fir::ReferenceType>().getEleTy();
-              reductionVars.push_back(symVal);
-              assert(redType.isIntOrIndexOrFloat() &&
-                     "Unsupported reduction type");
-              decl = createReductionDecl(
-                  firOpBuilder,
-                  getReductionName(getRealName(*reductionIntrinsic).ToString(),
-                                   redType),
-                  redId, redType, currentLocation);
-              reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
-                  firOpBuilder.getContext(), decl.getSymName()));
-            }
-          }
-        }
-      }
     }
   }
 };
@@ -2365,7 +2378,7 @@ getSimdModifier(const omp::clause::Schedule &clause) {
 
 static void
 genAllocateClause(Fortran::lower::AbstractConverter &converter,
-                  const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+                  const omp::clause::Allocate &clause,
                   llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
                   llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -2373,21 +2386,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
 
   mlir::Value allocatorOperand;
-  const Fortran::parser::OmpObjectList &ompObjectList =
-      std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
-  const auto &allocateModifier = std::get<
-      std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
-      ompAllocateClause.t);
+  const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t);
+  const auto &modifier =
+      std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t);
 
   // If the allocate modifier is present, check if we only use the allocator
   // submodifier.  ALIGN in this context is unimplemented
   const bool onlyAllocator =
-      allocateModifier &&
-      std::holds_alternative<
-          Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
-          allocateModifier->u);
+      modifier &&
+      std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>(
+          modifier->u);
 
-  if (allocateModifier && !onlyAllocator) {
+  if (modifier && !onlyAllocator) {
     TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
   }
 
@@ -2395,20 +2405,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
   // to list of allocators, otherwise, add default allocator to
   // list of allocators.
   if (onlyAllocator) {
-    const auto &allocatorValue = std::get<
-        Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
-        allocateModifier->u);
-    allocatorOperand = fir...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81623


More information about the llvm-branch-commits mailing list