[llvm-branch-commits] [flang] [llvm] [flang][OpenMP] Main splitting functionality dev-complete (PR #82003)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 16 07:50:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

[flang][OpenMP] TableGen support for getting leaf constructs

Implement getLeafConstructs(D), which for a composite directive D will return the list of the constituent leaf directives.

[flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering

Right now attributes like OpenMP version or target attributes for offload are set after lowering in bbc. The flang frontend sets them before lowering, making them available in the lowering process.

This change sets them before lowering in bbc as well.

getOpenMPVersion

---

Patch is 63.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82003.diff


6 Files Affected:

- (modified) flang/lib/Lower/OpenMP.cpp (+1034-10) 
- (modified) flang/tools/bbc/bbc.cpp (+1-1) 
- (modified) llvm/include/llvm/Frontend/Directive/DirectiveBase.td (+4) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+54-6) 
- (modified) llvm/include/llvm/TableGen/DirectiveEmitter.h (+4) 
- (modified) llvm/utils/TableGen/DirectiveEmitter.cpp (+77) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e45ab842b15556..ed6a0063848b18 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Support/CommandLine.h"
@@ -48,6 +49,29 @@ using DeclareTargetCapturePair =
 // Common helper functions
 //===----------------------------------------------------------------------===//
 
+// Returns the directives that are worksharing constructs.
+static llvm::ArrayRef<llvm::omp::Directive> getWorksharing() {
+  // The table is never modified; ArrayRef only needs read access, so keep
+  // the storage const.
+  static const llvm::omp::Directive worksharing[] = {
+      llvm::omp::Directive::OMPD_do,     llvm::omp::Directive::OMPD_for,
+      llvm::omp::Directive::OMPD_scope,  llvm::omp::Directive::OMPD_sections,
+      llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare,
+  };
+  return worksharing;
+}
+
+// Returns the directives that are worksharing-loop constructs.
+static llvm::ArrayRef<llvm::omp::Directive> getWorksharingLoop() {
+  // Read-only table exposed through ArrayRef; keep the storage const.
+  static const llvm::omp::Directive worksharingLoop[] = {
+      llvm::omp::Directive::OMPD_do,
+      llvm::omp::Directive::OMPD_for,
+  };
+  return worksharingLoop;
+}
+
+// Returns the OpenMP version stored in the module's "omp.version" attribute.
+// The attribute is required to be present; a missing attribute is treated as
+// an internal error (llvm_unreachable).
+static uint32_t getOpenMPVersion(const mlir::ModuleOp &mod) {
+  if (mlir::Attribute verAttr = mod->getAttr("omp.version"))
+    return llvm::cast<mlir::omp::VersionAttr>(verAttr).getVersion();
+  llvm_unreachable("Expecting OpenMP version attribute in module");
+}
+
 static Fortran::semantics::Symbol *
 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
   Fortran::semantics::Symbol *sym = nullptr;
@@ -166,6 +190,15 @@ struct SymDsgExtractor {
     return t;
   }
 
+  static semantics::Symbol *symbol_addr(const evaluate::SymbolRef &ref) {
+    // No symbols are created once semantic checks have run, so any symbol
+    // we see here is one of those pre-existing objects. Symbols are
+    // routinely handled through non-const pointers in this code, so
+    // stripping constness from the referenced symbol is harmless.
+    const semantics::Symbol &symbol = ref.get();
+    return const_cast<semantics::Symbol *>(&symbol);
+  }
+
   template <typename T> //
   static SymDsg visit(T &&) {
     // Use this to see missing overloads:
@@ -175,19 +208,12 @@ struct SymDsgExtractor {
 
   template <typename T> //
   static SymDsg visit(const evaluate::Designator<T> &e) {
-    // Symbols cannot be created after semantic checks, so all symbol
-    // pointers that are non-null must point to one of those pre-existing
-    // objects. Throughout the code, symbols are often pointed to by
-    // non-const pointers, so there is no harm in casting the constness
-    // away.
-    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetLastSymbol()),
+    // GetLastSymbol() returns a pointer that may be null, so only form the
+    // reference needed by symbol_addr when a symbol is actually present.
+    const semantics::Symbol *last = e.GetLastSymbol();
+    return std::make_tuple(last ? symbol_addr(*last) : nullptr,
                            evaluate::AsGenericExpr(AsRvalueRef(e)));
   }
 
   static SymDsg visit(const evaluate::ProcedureDesignator &e) {
-    // See comment above regarding const_cast.
-    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetSymbol()),
-                           std::nullopt);
+    // GetSymbol() also returns a possibly-null pointer; guard the dereference.
+    const semantics::Symbol *sym = e.GetSymbol();
+    return std::make_tuple(sym ? symbol_addr(*sym) : nullptr, std::nullopt);
   }
 
   template <typename T> //
@@ -313,6 +339,42 @@ std::optional<U> maybeApply(F &&func, const std::optional<T> &inp) {
   return std::move(func(*inp));
 }
 
+// Returns the object that `object` is a part of, if any:
+//  - for a component reference `a%b`, the base `a`;
+//  - for an array reference whose base is a component (`a%b(i)`), that
+//    component `a%b`.
+// Returns std::nullopt when `object` is a bare symbol, when its designator is
+// not a data reference, or when the base is a plain symbol.
+std::optional<Object>
+getBaseObject(const Object &object,
+              Fortran::semantics::SemanticsContext &semaCtx) {
+  // If it's just the symbol, then there is no base.
+  if (!object.dsg)
+    return std::nullopt;
+
+  auto maybeRef = evaluate::ExtractDataRef(*object.dsg);
+  if (!maybeRef)
+    return std::nullopt;
+
+  // Bind by reference: copying the DataRef out of the optional would be a
+  // needless deep copy of its nested references.
+  const evaluate::DataRef &ref = *maybeRef;
+
+  if (std::holds_alternative<evaluate::SymbolRef>(ref.u))
+    return std::nullopt;
+
+  if (auto *comp = std::get_if<evaluate::Component>(&ref.u)) {
+    const evaluate::DataRef &base = comp->base();
+    return Object{SymDsgExtractor::symbol_addr(base.GetLastSymbol()),
+                  evaluate::AsGenericExpr(SymDsgExtractor::AsRvalueRef(base))};
+  }
+
+  if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u)) {
+    const evaluate::NamedEntity &base = arr->base();
+    // Only a component base produces a base object; a plain symbol base
+    // (or anything else) does not.
+    if (auto *comp = base.UnwrapComponent()) {
+      evaluate::ExpressionAnalyzer ea{semaCtx};
+      return Object{
+          SymDsgExtractor::symbol_addr(comp->symbol()),
+          ea.Designate(evaluate::DataRef{SymDsgExtractor::AsRvalueRef(*comp)})};
+    }
+    return std::nullopt;
+  }
+
+  assert(std::holds_alternative<evaluate::CoarrayRef>(ref.u));
+  llvm_unreachable("Coarray reference not supported at the moment");
+}
+
 namespace clause {
 #ifdef EMPTY_CLASS
 #undef EMPTY_CLASS
@@ -1220,11 +1282,18 @@ struct Clause {
   clause::UnionOfAllClauses u;
 };
 
+// Builds a Clause from its id, its clause-specific payload and an optional
+// source location. The payload is perfectly forwarded so that temporaries
+// (e.g. the result of clause::make) are moved into the clause, not copied.
+template <typename Specific>
+Clause makeClause(llvm::omp::Clause id, Specific &&specific,
+                  parser::CharBlock source = {}) {
+  return Clause{source, id, std::forward<Specific>(specific)};
+}
+
+// Converts a parser-level OmpClause into an omp::Clause, keeping the
+// clause's source location and clause id.
 Clause makeClause(const Fortran::parser::OmpClause &cls,
                   semantics::SemanticsContext &semaCtx) {
   return std::visit(
       [&](auto &&s) {
+        // `s` is the active alternative of the parser clause variant.
-        return Clause{cls.source, getClauseId(cls), clause::make(s, semaCtx)};
+        return makeClause(getClauseId(cls), clause::make(s, semaCtx),
+                          cls.source);
       },
       cls.u);
 }
@@ -1263,6 +1332,957 @@ static void gatherFuncAndVarSyms(
     symbolAndClause.emplace_back(clause, *object.sym);
 }
 
+//===----------------------------------------------------------------------===//
+// Directive decomposition
+//===----------------------------------------------------------------------===//
+
+namespace {
+// A single leaf directive of a decomposed composite construct, together with
+// the clauses that have been assigned to it.
+struct DirectiveInfo {
+  llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown;
+  llvm::SmallVector<const omp::Clause *> clauses;
+};
+
+// State used to distribute the clauses of a composite directive among its
+// leaf directives.
+struct CompositeInfo {
+  CompositeInfo(const mlir::ModuleOp &modOp,
+                Fortran::semantics::SemanticsContext &semaCtx,
+                Fortran::lower::pft::Evaluation &ev,
+                llvm::omp::Directive compDir,
+                const Fortran::parser::OmpClauseList &clauseList);
+  using ClauseSet = std::set<const omp::Clause *>;
+
+  bool split();
+  void addClauseSymbols(const omp::Clause &clause);
+
+  // Returns the leaf with the given directive id, or nullptr if there is none.
+  DirectiveInfo *findDirective(llvm::omp::Directive dirId) {
+    for (DirectiveInfo &dir : leafs) {
+      if (dir.id == dirId)
+        return &dir;
+    }
+    return nullptr;
+  }
+  // Returns the set of clauses that reference the object's symbol, or
+  // nullptr if no clause references it.
+  ClauseSet *findClauses(const omp::Object &object) {
+    if (auto found = syms.find(object.sym); found != syms.end())
+      return &found->second;
+    return nullptr;
+  }
+
+  Fortran::semantics::SemanticsContext &semaCtx;
+  // NOTE(review): this holds a reference to a ModuleOp handle, so the
+  // referenced ModuleOp object must outlive this CompositeInfo. Confirm this
+  // with the callers, or consider storing the handle by value.
+  const mlir::ModuleOp &mod;
+  Fortran::lower::pft::Evaluation &eval;
+
+  llvm::SmallVector<DirectiveInfo> leafs; // Ordered outer to inner.
+  // All clauses of the composite directive, converted from the parser list.
+  omp::List<omp::Clause> clauses;
+  // Maps each symbol to the set of clauses whose objects reference it.
+  llvm::DenseMap<const Fortran::semantics::Symbol *, ClauseSet> syms;
+  // Symbols of the base objects of items that appear in MAP clauses.
+  llvm::DenseSet<const Fortran::semantics::Symbol *> mapBases;
+  // Storage for newly created clauses. Beware of invalidating addresses.
+  std::list<omp::Clause> extras;
+
+private:
+  // Record, for each symbol found in a clause's contents, that the clause
+  // references it. The overloads below walk the clause structure.
+  void addClauseSymsToMap(const omp::Object &object, const omp::Clause *);
+  void addClauseSymsToMap(const omp::ObjectList &objects, const omp::Clause *);
+  void addClauseSymsToMap(const omp::SomeExpr &item, const omp::Clause *);
+  void addClauseSymsToMap(const omp::clause::Map &item, const omp::Clause *);
+
+  template <typename T>
+  void addClauseSymsToMap(const std::optional<T> &item, const omp::Clause *);
+  template <typename T>
+  void addClauseSymsToMap(const omp::List<T> &item, const omp::Clause *);
+  template <typename... T, size_t... Is>
+  void addClauseSymsToMap(const std::tuple<T...> &item, const omp::Clause *,
+                          std::index_sequence<Is...> = {});
+  template <typename T,
+            std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<T>>, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::EmptyTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::WrapperTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::TupleTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::UnionTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+
+  // Apply a clause to the only directive that allows it. If there are no
+  // directives that allow it, or if there is more than one, do not apply
+  // anything and return false, otherwise return true.
+  bool applyToUnique(const omp::Clause *node);
+
+  // Apply a clause to the first directive in given range that allows it.
+  // If such a directive does not exist, return false, otherwise return true.
+  template <typename Iterator>
+  bool applyToFirst(const omp::Clause *node, const mlir::ModuleOp &mod,
+                    llvm::iterator_range<Iterator> range);
+
+  // Apply a clause to the innermost directive that allows it. If such a
+  // directive does not exist, return false, otherwise return true.
+  bool applyToInnermost(const omp::Clause *node);
+
+  // Apply a clause to the outermost directive that allows it. If such a
+  // directive does not exist, return false, otherwise return true.
+  bool applyToOutermost(const omp::Clause *node);
+
+  // Apply a clause to every directive that allows it and for which
+  // `shouldApply` returns true. Returns true if it was applied to any.
+  template <typename Predicate>
+  bool applyIf(const omp::Clause *node, Predicate shouldApply);
+
+  // Apply a clause to every directive that allows it.
+  bool applyToAll(const omp::Clause *node);
+
+  // Generic fallback and clause-specific placement rules.
+  template <typename Clause>
+  bool applyClause(Clause &&clause, const omp::Clause *node);
+
+  bool applyClause(const omp::clause::Collapse &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Private &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Firstprivate &clause,
+                   const omp::Clause *);
+  bool applyClause(const omp::clause::Lastprivate &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Shared &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Default &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::ThreadLimit &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Order &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Allocate &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Reduction &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::If &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Linear &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Nowait &clause, const omp::Clause *);
+};
+} // namespace
+
+CompositeInfo::CompositeInfo(const mlir::ModuleOp &modOp,
+                             Fortran::semantics::SemanticsContext &semaCtx,
+                             Fortran::lower::pft::Evaluation &ev,
+                             llvm::omp::Directive compDir,
+                             const Fortran::parser::OmpClauseList &clauseList)
+    : semaCtx(semaCtx), mod(modOp), eval(ev),
+      clauses(omp::makeList(clauseList, semaCtx)) {
+  // One DirectiveInfo (initially without clauses) per constituent leaf.
+  for (llvm::omp::Directive leafId : llvm::omp::getLeafConstructs(compDir)) {
+    leafs.push_back(DirectiveInfo{leafId});
+  }
+
+  // Index every symbol mentioned by a clause with the clause that uses it.
+  for (const omp::Clause &item : clauses) {
+    const omp::Clause *itemPtr = &item;
+    addClauseSymsToMap(item, itemPtr);
+  }
+}
+
+[[maybe_unused]] static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const DirectiveInfo &dirInfo) {
+  // Directive name, then a tab before the first clause name and a single
+  // space before each following one.
+  os << llvm::omp::getOpenMPDirectiveName(dirInfo.id);
+  bool first = true;
+  for (const omp::Clause *clause : dirInfo.clauses) {
+    os << (first ? '\t' : ' ') << llvm::omp::getOpenMPClauseName(clause->id);
+    first = false;
+  }
+  return os;
+}
+
+[[maybe_unused]] static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const CompositeInfo &compInfo) {
+  // One line per leaf directive with the clauses assigned to it.
+  size_t leafIdx = 0;
+  for (const DirectiveInfo &leaf : compInfo.leafs)
+    os << "leaf[" << leafIdx++ << "]: " << leaf << '\n';
+
+  // Symbol -> names of the clauses that reference it.
+  os << "syms:\n";
+  for (const auto &entry : compInfo.syms) {
+    os << *entry.first << " -> {";
+    for (const omp::Clause *clause : entry.second)
+      os << ' ' << llvm::omp::getOpenMPClauseName(clause->id);
+    os << " }\n";
+  }
+
+  // Symbols recorded as bases of MAP clause items.
+  os << "mapBases: {";
+  for (const Fortran::semantics::Symbol *base : compInfo.mapBases)
+    os << ' ' << *base;
+  os << " }\n";
+  return os;
+}
+
+namespace detail {
+// Returns an iterator to the only element of `container` that satisfies
+// `pred`. Returns `container.end()` if no element satisfies it, or if more
+// than one does.
+template <typename Container, typename Predicate>
+typename std::remove_reference_t<Container>::iterator
+find_unique(Container &&container, Predicate &&pred) {
+  const auto last = container.end();
+  const auto match = std::find_if(container.begin(), last, pred);
+  if (match == last)
+    return last;
+  // A second match means the first one is not unique.
+  const bool ambiguous = std::find_if(std::next(match), last, pred) != last;
+  return ambiguous ? last : match;
+}
+} // namespace detail
+
+// Returns the symbol of the loop variable when `eval` is a DO construct whose
+// loop control is of the bounds form; returns nullptr in every other case.
+static Fortran::semantics::Symbol *
+getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) {
+  using SymbolPtr = Fortran::semantics::Symbol *;
+  return eval.visit(Fortran::common::visitors{
+      [](const Fortran::parser::DoConstruct &doLoop) -> SymbolPtr {
+        const auto &maybeCtrl = doLoop.GetLoopControl();
+        if (!maybeCtrl)
+          return nullptr;
+        using LoopControl = Fortran::parser::LoopControl;
+        const auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u);
+        if (!bounds)
+          return nullptr;
+        static_assert(
+            std::is_same_v<decltype(bounds->name),
+                           Fortran::parser::Scalar<Fortran::parser::Name>>);
+        return bounds->name.thing.symbol;
+      },
+      // Any other kind of evaluation has no iteration variable.
+      [](auto &&) -> SymbolPtr { return nullptr; },
+  });
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::Object &object,
+                                       const omp::Clause *node) {
+  // Associate the object's symbol with the clause that names it.
+  ClauseSet &users = syms[object.sym];
+  users.insert(node);
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::ObjectList &objects,
+                                       const omp::Clause *node) {
+  // Every object in the list is associated with the same clause.
+  for (const omp::Object &obj : objects)
+    syms[obj.sym].insert(node);
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::SomeExpr &,
+                                       const omp::Clause *) {
+  // Expressions are not tracked, so there is nothing to record.
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::clause::Map &item,
+                                       const omp::Clause *node) {
+  const omp::ObjectList &mapped = std::get<omp::ObjectList>(item.t);
+  // Record the mapped objects themselves...
+  addClauseSymsToMap(mapped, node);
+  // ...then remember the symbols of their base objects (as computed by
+  // getBaseObject, e.g. the parent of a mapped component).
+  for (const omp::Object &obj : mapped) {
+    std::optional<omp::Object> base = omp::getBaseObject(obj, semaCtx);
+    if (base)
+      mapBases.insert(base->sym);
+  }
+}
+
+template <typename T>
+void CompositeInfo::addClauseSymsToMap(const std::optional<T> &item,
+                                       const omp::Clause *node) {
+  // An absent optional contributes nothing.
+  if (!item)
+    return;
+  addClauseSymsToMap(*item, node);
+}
+
+template <typename T>
+void CompositeInfo::addClauseSymsToMap(const omp::List<T> &item,
+                                       const omp::Clause *node) {
+  // Visit the list elements in order.
+  for (const auto &element : item)
+    addClauseSymsToMap(element, node);
+}
+
+// Visit the tuple elements selected by the index pack `Is` (the TupleTrait
+// overload passes all indices; the declaration defaults to an empty pack).
+template <typename... T, size_t... Is>
+void CompositeInfo::addClauseSymsToMap(const std::tuple<T...> &item,
+                                       const omp::Clause *node,
+                                       std::index_sequence<Is...>) {
+  // Presumably needed when the pack is empty and the fold below expands to
+  // nothing, leaving `node` unused.
+  (void)node; // Silence strange warning from GCC.
+  // Comma fold: visit each selected element in index order.
+  (addClauseSymsToMap(std::get<Is>(item), node), ...);
+}
+
+// Out-of-class definitions of the SFINAE-constrained overloads. The default
+// template arguments (`= 0`) belong only on the in-class declarations;
+// repeating them here redefines a default argument, which compilers reject.
+
+// Enumerations carry no objects.
+template <typename T,
+          std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<T>>, int>>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  // Nothing to do for enums.
+}
+
+// Empty classes carry no objects.
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::EmptyTrait::value, int>>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  // Nothing to do for an empty class.
+}
+
+// Wrapper classes: descend into the wrapped value `v`.
+template <
+    typename T,
+    std::enable_if_t<llvm::remove_cvref_t<T>::WrapperTrait::value, int>>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  addClauseSymsToMap(item.v, node);
+}
+
+// Tuple classes: visit every member of the tuple `t`.
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::TupleTrait::value, int>>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  constexpr size_t tuple_size =
+      std::tuple_size_v<llvm::remove_cvref_t<decltype(item.t)>>;
+  addClauseSymsToMap(item.t, node, std::make_index_sequence<tuple_size>{});
+}
+
+// Union classes: visit whichever alternative of `u` is active.
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::UnionTrait::value, int>>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  std::visit([&](auto &&s) { addClauseSymsToMap(s, node); }, item.u);
+}
+
+#if 1
+// Attach a clause to the single leaf directive that allows it. If no leaf
+// allows it, or if more than one does, nothing is attached and false is
+// returned; otherwise the clause is attached and true is returned.
+bool CompositeInfo::applyToUnique(const omp::Clause *node) {
+  const uint32_t version = getOpenMPVersion(mod);
+  auto allowsClause = [=](const DirectiveInfo &dirInfo) {
+    return llvm::omp::isAllowedClauseForDirective(dirInfo.id, node->id,
+                                                  version);
+  };
+
+  auto unique = detail::find_unique(leafs, allowsClause);
+  if (unique == leafs.end())
+    return false;
+
+  unique->clauses.push_back(node);
+  return true;
+}
+
+// Attach a clause to the first directive in `range` that allows it. Returns
+// true if such a directive was found, false otherwise.
+template <typename Iterator>
+bool CompositeInfo::applyToFirst(const omp::Clause *node,
+                                 const mlir::ModuleOp &mod,
+                                 llvm::iterator_range<Iterator> range) {
+  // Nothing to search: bail out before querying the module.
+  if (range.empty())
+    return false;
+
+  const uint32_t version = getOpenMPVersion(mod);
+  for (DirectiveInfo &dir : range) {
+    if (llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version)) {
+      dir.clauses.push_back(node);
+      return true;
+    }
+  }
+  return false;
+}
+
+// Attach a clause to the innermost directive that allows it. Returns false
+// if no such directive exists, true otherwise.
+bool CompositeInfo::applyToInnermost(const omp::Clause *node) {
+  // Leaves are stored outer to inner, so walk them backwards.
+  auto innerToOuter = llvm::reverse(leafs);
+  return applyToFirst(node, mod, innerToOuter);
+}
+
+// Attach a clause to the outermost directive that allows it. Returns false
+// if no such directive exists, true otherwise.
+bool CompositeInfo::applyToOutermost(const omp::Clause *node) {
+  // Leaves are stored outer to inner, so walk them in storage order.
+  llvm::iterator_range outerToInner(leafs);
+  return applyToFirst(node, mod, outerToInner);
+}
+
+// Attach a clause to every leaf that both allows it (for the module's OpenMP
+// version) and is accepted by `shouldApply`. Returns true if the clause was
+// attached to at least one leaf.
+template <typename Predicate>
+bool CompositeInfo::applyIf(const omp::Clause *node, Predicate shouldApply) {
+  const uint32_t version = getOpenMPVersion(mod);
+  bool appliedAny = false;
+  for (DirectiveInfo &dir : leafs) {
+    bool allowed =
+        llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version);
+    // The predicate is consulted only for leaves that allow the clause.
+    if (allowed && shouldApply(dir)) {
+      dir.clauses.push_back(node);
+      appliedAny = true;
+    }
+  }
+  return appliedAny;
+}
+
+// Attach a clause to every leaf directive that allows it.
+bool CompositeInfo::applyToAll(const omp::Clause *node) {
+  // Take the leaf by const reference: a by-value `auto` parameter would copy
+  // the DirectiveInfo, including its clause vector, on every call.
+  return applyIf(node, [](const auto &) { return true; });
+}
+
+template <typename Clause>
+bool CompositeInfo::applyClause(Clause &&clause, const omp::Clause *node) {
+  // Default placement: attach the clause to the unique leaf directive that
+  // allows it. When no leaf, or more than one leaf, allows it, nothing is
+  // attached and false is returned.
+  // From "OpenMP Application Programming Interface", Version 5.2:
+  // S Some clauses are permitted only on a single leaf construct of the
+  // S combined or composite construct, in which case the effect is as if
+  // S the clause is applied to that specific construct. (p339, 31-33)
+  return applyToUnique(node);
+}
+
+// COLLAPSE
+bool CompositeInfo::applyClause(const omp::clause::Collapse &clause,
+                                const omp::Clause *node) {
+  // Apply COLLAPSE to the innermost directive. If it's not one that
+  // allows it flag an error.
+  if (!leafs.empty()) {
+    DirectiveInfo &last = leafs.back();
+    uint32_t version = getOpenMPVersion(mod);
+
+    if (llvm::omp::isAllowedClauseForDirectiv...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82003


More information about the llvm-branch-commits mailing list