[llvm-branch-commits] [flang] [llvm] [flang][OpenMP] Main splitting functionality dev-complete (PR #82003)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 16 07:49:58 PST 2024


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/82003

[flang][OpenMP] TableGen support for getting leaf constructs

Implement getLeafConstructs(D), which for a composite directive D will return the list of the constituent leaf directives.

[flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering

Right now attributes like OpenMP version or target attributes for offload are set after lowering in bbc. The flang frontend sets them before lowering, making them available in the lowering process.

This change sets them before lowering in bbc as well.

Add getOpenMPVersion() to read the OpenMP version from the MLIR module.

>From ac2d8fd31c0a2b8f818a73a619496d5263c3ccb8 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 16 Jan 2024 16:40:47 -0600
Subject: [PATCH] [flang][OpenMP] Main splitting functionality dev-complete

[flang][OpenMP] TableGen support for getting leaf constructs

Implement getLeafConstructs(D), which for a composite directive D
will return the list of the constituent leaf directives.

[flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering

Right now attributes like OpenMP version or target attributes for offload
are set after lowering in bbc. The flang frontend sets them before lowering,
making them available in the lowering process.

This change sets them before lowering in bbc as well.

Add getOpenMPVersion() to read the OpenMP version from the MLIR module.
---
 flang/lib/Lower/OpenMP.cpp                    | 1044 ++++++++++++++++-
 flang/tools/bbc/bbc.cpp                       |    2 +-
 .../llvm/Frontend/Directive/DirectiveBase.td  |    4 +
 llvm/include/llvm/Frontend/OpenMP/OMP.td      |   60 +-
 llvm/include/llvm/TableGen/DirectiveEmitter.h |    4 +
 llvm/utils/TableGen/DirectiveEmitter.cpp      |   77 ++
 6 files changed, 1174 insertions(+), 17 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e45ab842b15556..ed6a0063848b18 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Support/CommandLine.h"
@@ -48,6 +49,29 @@ using DeclareTargetCapturePair =
 // Common helper functions
 //===----------------------------------------------------------------------===//
 
+static llvm::ArrayRef<llvm::omp::Directive> getWorksharing() {
+  static llvm::omp::Directive worksharing[] = {
+      llvm::omp::Directive::OMPD_do,     llvm::omp::Directive::OMPD_for,
+      llvm::omp::Directive::OMPD_scope,  llvm::omp::Directive::OMPD_sections,
+      llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare,
+  };
+  return worksharing;
+}
+
+static llvm::ArrayRef<llvm::omp::Directive> getWorksharingLoop() {
+  static llvm::omp::Directive worksharingLoop[] = {
+      llvm::omp::Directive::OMPD_do,
+      llvm::omp::Directive::OMPD_for,
+  };
+  return worksharingLoop;
+}
+
+static uint32_t getOpenMPVersion(const mlir::ModuleOp &mod) {
+  if (mlir::Attribute verAttr = mod->getAttr("omp.version"))
+    return llvm::cast<mlir::omp::VersionAttr>(verAttr).getVersion();
+  llvm_unreachable("Expecting OpenMP version attribute in module");
+}
+
 static Fortran::semantics::Symbol *
 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
   Fortran::semantics::Symbol *sym = nullptr;
@@ -166,6 +190,15 @@ struct SymDsgExtractor {
     return t;
   }
 
+  static semantics::Symbol *symbol_addr(const evaluate::SymbolRef &ref) {
+    // Symbols cannot be created after semantic checks, so all symbol
+    // pointers that are non-null must point to one of those pre-existing
+    // objects. Throughout the code, symbols are often pointed to by
+    // non-const pointers, so there is no harm in casting the constness
+    // away.
+    return const_cast<semantics::Symbol *>(&ref.get());
+  }
+
   template <typename T> //
   static SymDsg visit(T &&) {
     // Use this to see missing overloads:
@@ -175,19 +208,12 @@ struct SymDsgExtractor {
 
   template <typename T> //
   static SymDsg visit(const evaluate::Designator<T> &e) {
-    // Symbols cannot be created after semantic checks, so all symbol
-    // pointers that are non-null must point to one of those pre-existing
-    // objects. Throughout the code, symbols are often pointed to by
-    // non-const pointers, so there is no harm in casting the constness
-    // away.
-    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetLastSymbol()),
+    return std::make_tuple(symbol_addr(*e.GetLastSymbol()),
                            evaluate::AsGenericExpr(AsRvalueRef(e)));
   }
 
   static SymDsg visit(const evaluate::ProcedureDesignator &e) {
-    // See comment above regarding const_cast.
-    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetSymbol()),
-                           std::nullopt);
+    return std::make_tuple(symbol_addr(*e.GetSymbol()), std::nullopt);
   }
 
   template <typename T> //
@@ -313,6 +339,42 @@ std::optional<U> maybeApply(F &&func, const std::optional<T> &inp) {
   return std::move(func(*inp));
 }
 
+std::optional<Object>
+getBaseObject(const Object &object,
+              Fortran::semantics::SemanticsContext &semaCtx) {
+  // If it's just the symbol, then there is no base.
+  if (!object.dsg)
+    return std::nullopt;
+
+  auto maybeRef = evaluate::ExtractDataRef(*object.dsg);
+  if (!maybeRef)
+    return std::nullopt;
+
+  evaluate::DataRef ref = *maybeRef;
+
+  if (std::get_if<evaluate::SymbolRef>(&ref.u)) {
+    return std::nullopt;
+  } else if (auto *comp = std::get_if<evaluate::Component>(&ref.u)) {
+    const evaluate::DataRef &base = comp->base();
+    return Object{SymDsgExtractor::symbol_addr(base.GetLastSymbol()),
+                  evaluate::AsGenericExpr(SymDsgExtractor::AsRvalueRef(base))};
+  } else if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u)) {
+    const evaluate::NamedEntity &base = arr->base();
+    evaluate::ExpressionAnalyzer ea{semaCtx};
+    if (auto *comp = base.UnwrapComponent()) {
+      return Object{
+          SymDsgExtractor::symbol_addr(comp->symbol()),
+          ea.Designate(evaluate::DataRef{SymDsgExtractor::AsRvalueRef(*comp)})};
+    } else if (base.UnwrapSymbolRef()) {
+      return std::nullopt;
+    }
+  } else {
+    assert(std::holds_alternative<evaluate::CoarrayRef>(ref.u));
+    llvm_unreachable("Coarray reference not supported at the moment");
+  }
+  return std::nullopt;
+}
+
 namespace clause {
 #ifdef EMPTY_CLASS
 #undef EMPTY_CLASS
@@ -1220,11 +1282,18 @@ struct Clause {
   clause::UnionOfAllClauses u;
 };
 
+template <typename Specific>
+Clause makeClause(llvm::omp::Clause id, Specific &&specific,
+                  parser::CharBlock source = {}) {
+  return Clause{source, id, specific};
+}
+
 Clause makeClause(const Fortran::parser::OmpClause &cls,
                   semantics::SemanticsContext &semaCtx) {
   return std::visit(
       [&](auto &&s) {
-        return Clause{cls.source, getClauseId(cls), clause::make(s, semaCtx)};
+        return makeClause(getClauseId(cls), clause::make(s, semaCtx),
+                          cls.source);
       },
       cls.u);
 }
@@ -1263,6 +1332,957 @@ static void gatherFuncAndVarSyms(
     symbolAndClause.emplace_back(clause, *object.sym);
 }
 
+//===----------------------------------------------------------------------===//
+// Directive decomposition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct DirectiveInfo {
+  llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown;
+  llvm::SmallVector<const omp::Clause *> clauses;
+};
+
+struct CompositeInfo {
+  CompositeInfo(const mlir::ModuleOp &modOp,
+                Fortran::semantics::SemanticsContext &semaCtx,
+                Fortran::lower::pft::Evaluation &ev,
+                llvm::omp::Directive compDir,
+                const Fortran::parser::OmpClauseList &clauseList);
+  using ClauseSet = std::set<const omp::Clause *>;
+
+  bool split();
+  void addClauseSymbols(const omp::Clause &clause);
+
+  DirectiveInfo *findDirective(llvm::omp::Directive dirId) {
+    for (DirectiveInfo &dir : leafs) {
+      if (dir.id == dirId)
+        return &dir;
+    }
+    return nullptr;
+  }
+  ClauseSet *findClauses(const omp::Object &object) {
+    if (auto found = syms.find(object.sym); found != syms.end())
+      return &found->second;
+    return nullptr;
+  }
+
+  Fortran::semantics::SemanticsContext &semaCtx;
+  const mlir::ModuleOp &mod;
+  Fortran::lower::pft::Evaluation &eval;
+
+  llvm::SmallVector<DirectiveInfo> leafs; // Ordered outer to inner.
+  omp::List<omp::Clause> clauses;
+  llvm::DenseMap<const Fortran::semantics::Symbol *, ClauseSet> syms;
+  llvm::DenseSet<const Fortran::semantics::Symbol *> mapBases;
+  // Storage for newly created clauses. Beware of invalidating addresses.
+  std::list<omp::Clause> extras;
+
+private:
+  void addClauseSymsToMap(const omp::Object &object, const omp::Clause *);
+  void addClauseSymsToMap(const omp::ObjectList &objects, const omp::Clause *);
+  void addClauseSymsToMap(const omp::SomeExpr &item, const omp::Clause *);
+  void addClauseSymsToMap(const omp::clause::Map &item, const omp::Clause *);
+
+  template <typename T>
+  void addClauseSymsToMap(const std::optional<T> &item, const omp::Clause *);
+  template <typename T>
+  void addClauseSymsToMap(const omp::List<T> &item, const omp::Clause *);
+  template <typename... T, size_t... Is>
+  void addClauseSymsToMap(const std::tuple<T...> &item, const omp::Clause *,
+                          std::index_sequence<Is...> = {});
+  template <typename T,
+            std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<T>>, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::EmptyTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::WrapperTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::TupleTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::UnionTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+
+  // Apply a clause to the only directive that allows it. If there are no
+  // directives that allow it, or if there is more than one, do not apply
+  // anything and return false, otherwise return true.
+  bool applyToUnique(const omp::Clause *node);
+
+  // Apply a clause to the first directive in given range that allows it.
+  // If such a directive does not exist, return false, otherwise return true.
+  template <typename Iterator>
+  bool applyToFirst(const omp::Clause *node, const mlir::ModuleOp &mod,
+                    llvm::iterator_range<Iterator> range);
+
+  // Apply a clause to the innermost directive that allows it. If such a
+  // directive does not exist, return false, otherwise return true.
+  bool applyToInnermost(const omp::Clause *node);
+
+  // Apply a clause to the outermost directive that allows it. If such a
+  // directive does not exist, return false, otherwise return true.
+  bool applyToOutermost(const omp::Clause *node);
+
+  template <typename Predicate>
+  bool applyIf(const omp::Clause *node, Predicate shouldApply);
+
+  bool applyToAll(const omp::Clause *node);
+
+  template <typename Clause>
+  bool applyClause(Clause &&clause, const omp::Clause *node);
+
+  bool applyClause(const omp::clause::Collapse &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Private &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Firstprivate &clause,
+                   const omp::Clause *);
+  bool applyClause(const omp::clause::Lastprivate &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Shared &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Default &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::ThreadLimit &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Order &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Allocate &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Reduction &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::If &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Linear &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Nowait &clause, const omp::Clause *);
+};
+} // namespace
+
+CompositeInfo::CompositeInfo(const mlir::ModuleOp &modOp,
+                             Fortran::semantics::SemanticsContext &semaCtx,
+                             Fortran::lower::pft::Evaluation &ev,
+                             llvm::omp::Directive compDir,
+                             const Fortran::parser::OmpClauseList &clauseList)
+    : semaCtx(semaCtx), mod(modOp), eval(ev),
+      clauses(omp::makeList(clauseList, semaCtx)) {
+  for (llvm::omp::Directive dir : llvm::omp::getLeafConstructs(compDir))
+    leafs.push_back(DirectiveInfo{dir});
+
+  for (const omp::Clause &clause : clauses)
+    addClauseSymsToMap(clause, &clause);
+}
+
+[[maybe_unused]] static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const DirectiveInfo &dirInfo) {
+  os << llvm::omp::getOpenMPDirectiveName(dirInfo.id);
+  for (auto [index, clause] : llvm::enumerate(dirInfo.clauses)) {
+    os << (index == 0 ? '\t' : ' ');
+    os << llvm::omp::getOpenMPClauseName(clause->id);
+  }
+  return os;
+}
+
+[[maybe_unused]] static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const CompositeInfo &compInfo) {
+  for (const auto &[index, dirInfo] : llvm::enumerate(compInfo.leafs))
+    os << "leaf[" << index << "]: " << dirInfo << '\n';
+
+  os << "syms:\n";
+  for (const auto &[sym, clauses] : compInfo.syms) {
+    os << *sym << " -> {";
+    for (const auto *clause : clauses)
+      os << ' ' << llvm::omp::getOpenMPClauseName(clause->id);
+    os << " }\n";
+  }
+  os << "mapBases: {";
+  for (const auto &sym : compInfo.mapBases)
+    os << ' ' << *sym;
+  os << " }\n";
+  return os;
+}
+
+namespace detail {
+template <typename Container, typename Predicate>
+typename std::remove_reference_t<Container>::iterator
+find_unique(Container &&container, Predicate &&pred) {
+  auto first = std::find_if(container.begin(), container.end(), pred);
+  if (first == container.end())
+    return first;
+  auto second = std::find_if(std::next(first), container.end(), pred);
+  if (second == container.end())
+    return first;
+  return container.end();
+}
+} // namespace detail
+
+static Fortran::semantics::Symbol *
+getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) {
+  return eval.visit(Fortran::common::visitors{
+      [&](const Fortran::parser::DoConstruct &doLoop) {
+        if (const auto &maybeCtrl = doLoop.GetLoopControl()) {
+          using LoopControl = Fortran::parser::LoopControl;
+          if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) {
+            static_assert(
+                std::is_same_v<decltype(bounds->name),
+                               Fortran::parser::Scalar<Fortran::parser::Name>>);
+            return bounds->name.thing.symbol;
+          }
+        }
+        return static_cast<Fortran::semantics::Symbol *>(nullptr);
+      },
+      [](auto &&) {
+        return static_cast<Fortran::semantics::Symbol *>(nullptr);
+      },
+  });
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::Object &object,
+                                       const omp::Clause *node) {
+  syms[object.sym].insert(node);
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::ObjectList &objects,
+                                       const omp::Clause *node) {
+  for (auto &object : objects)
+    syms[object.sym].insert(node);
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::SomeExpr &expr,
+                                       const omp::Clause *node) {
+  // Nothing to do for expressions.
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::clause::Map &item,
+                                       const omp::Clause *node) {
+  auto &objects = std::get<omp::ObjectList>(item.t);
+  addClauseSymsToMap(objects, node);
+  for (auto &object : objects) {
+    if (auto base = omp::getBaseObject(object, semaCtx))
+      mapBases.insert(base->sym);
+  }
+}
+
+template <typename T>
+void CompositeInfo::addClauseSymsToMap(const std::optional<T> &item,
+                                       const omp::Clause *node) {
+  if (item)
+    addClauseSymsToMap(*item, node);
+}
+
+template <typename T>
+void CompositeInfo::addClauseSymsToMap(const omp::List<T> &item,
+                                       const omp::Clause *node) {
+  for (auto &s : item)
+    addClauseSymsToMap(s, node);
+}
+
+template <typename... T, size_t... Is>
+void CompositeInfo::addClauseSymsToMap(const std::tuple<T...> &item,
+                                       const omp::Clause *node,
+                                       std::index_sequence<Is...>) {
+  (void)node; // Silence strange warning from GCC.
+  (addClauseSymsToMap(std::get<Is>(item), node), ...);
+}
+
+template <typename T,
+          std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<T>>, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  // Nothing to do for enums.
+}
+
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::EmptyTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  // Nothing to do for an empty class.
+}
+
+template <
+    typename T,
+    std::enable_if_t<llvm::remove_cvref_t<T>::WrapperTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  addClauseSymsToMap(item.v, node);
+}
+
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::TupleTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  constexpr size_t tuple_size =
+      std::tuple_size_v<llvm::remove_cvref_t<decltype(item.t)>>;
+  addClauseSymsToMap(item.t, node, std::make_index_sequence<tuple_size>{});
+}
+
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::UnionTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  std::visit([&](auto &&s) { addClauseSymsToMap(s, node); }, item.u);
+}
+
+#if 1
+// Apply a clause to the only directive that allows it. If there are no
+// directives that allow it, or if there is more than one, do not apply
+// anything and return false, otherwise return true.
+bool CompositeInfo::applyToUnique(const omp::Clause *node) {
+  uint32_t version = getOpenMPVersion(mod);
+  auto unique = detail::find_unique(leafs, [=](const auto &dirInfo) {
+    return llvm::omp::isAllowedClauseForDirective(dirInfo.id, node->id,
+                                                  version);
+  });
+
+  if (unique != leafs.end()) {
+    unique->clauses.push_back(node);
+    return true;
+  }
+  return false;
+}
+
+// Apply a clause to the first directive in given range that allows it.
+// If such a directive does not exist, return false, otherwise return true.
+template <typename Iterator>
+bool CompositeInfo::applyToFirst(const omp::Clause *node,
+                                 const mlir::ModuleOp &mod,
+                                 llvm::iterator_range<Iterator> range) {
+  if (range.empty())
+    return false;
+
+  uint32_t version = getOpenMPVersion(mod);
+  for (DirectiveInfo &dir : range) {
+    if (!llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version))
+      continue;
+    dir.clauses.push_back(node);
+    return true;
+  }
+  return false;
+}
+
+// Apply a clause to the innermost directive that allows it. If such a
+// directive does not exist, return false, otherwise return true.
+bool CompositeInfo::applyToInnermost(const omp::Clause *node) {
+  return applyToFirst(node, mod, llvm::reverse(leafs));
+}
+
+// Apply a clause to the outermost directive that allows it. If such a
+// directive does not exist, return false, otherwise return true.
+bool CompositeInfo::applyToOutermost(const omp::Clause *node) {
+  return applyToFirst(node, mod, llvm::iterator_range(leafs));
+}
+
+template <typename Predicate>
+bool CompositeInfo::applyIf(const omp::Clause *node, Predicate shouldApply) {
+  bool applied = false;
+  uint32_t version = getOpenMPVersion(mod);
+  for (DirectiveInfo &dir : leafs) {
+    if (!llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version))
+      continue;
+    if (!shouldApply(dir))
+      continue;
+    dir.clauses.push_back(node);
+    applied = true;
+  }
+
+  return applied;
+}
+
+bool CompositeInfo::applyToAll(const omp::Clause *node) {
+  return applyIf(node, [](auto) { return true; });
+}
+
+template <typename Clause>
+bool CompositeInfo::applyClause(Clause &&clause, const omp::Clause *node) {
+  // The default behavior is to find the unique directive to which the
+  // given clause may be applied. If there are no such directives, or
+  // if there are multiple ones, flag an error.
+  // From "OpenMP Application Programming Interface", Version 5.2:
+  // S Some clauses are permitted only on a single leaf construct of the
+  // S combined or composite construct, in which case the effect is as if
+  // S the clause is applied to that specific construct. (p339, 31-33)
+  if (applyToUnique(node))
+    return true;
+
+  return false;
+}
+
+// COLLAPSE
+bool CompositeInfo::applyClause(const omp::clause::Collapse &clause,
+                                const omp::Clause *node) {
+  // Apply COLLAPSE to the innermost directive. If it's not one that
+  // allows it, flag an error.
+  if (!leafs.empty()) {
+    DirectiveInfo &last = leafs.back();
+    uint32_t version = getOpenMPVersion(mod);
+
+    if (llvm::omp::isAllowedClauseForDirective(last.id, node->id, version)) {
+      last.clauses.push_back(node);
+      return true;
+    }
+  }
+
+  llvm::errs() << "Cannot apply COLLAPSE\n";
+  return false;
+}
+
+// PRIVATE
+bool CompositeInfo::applyClause(const omp::clause::Private &clause,
+                                const omp::Clause *node) {
+  if (applyToInnermost(node))
+    return true;
+  llvm::errs() << "Cannot apply PRIVATE\n";
+  return false;
+}
+
+// FIRSTPRIVATE
+bool CompositeInfo::applyClause(const omp::clause::Firstprivate &clause,
+                                const omp::Clause *node) {
+  bool applied = false;
+
+  // S Section 17.2
+  // S The effect of the firstprivate clause is as if it is applied to one
+  // S or more leaf constructs as follows:
+
+  // S - To the distribute construct if it is among the constituent constructs;
+  // S - To the teams construct if it is among the constituent constructs and
+  // S   the distribute construct is not;
+  auto hasDistribute = findDirective(llvm::omp::OMPD_distribute);
+  auto hasTeams = findDirective(llvm::omp::OMPD_teams);
+  if (hasDistribute != nullptr) {
+    hasDistribute->clauses.push_back(node);
+    applied = true;
+    // S If the teams construct is among the constituent constructs and the
+    // S effect is not as if the firstprivate clause is applied to it by the
+    // S above rules, then the effect is as if the shared clause with the
+    // S same list item is applied to the teams construct.
+    if (hasTeams != nullptr) {
+      auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared,
+                                    omp::clause::Shared{clause.v});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(shared));
+      hasTeams->clauses.push_back(&n);
+    }
+  } else if (hasTeams != nullptr) {
+    hasTeams->clauses.push_back(node);
+    applied = true;
+  }
+
+  // S - To a worksharing construct that accepts the clause if one is among
+  // S   the constituent constructs;
+  auto findWorksharing = [&]() {
+    auto worksharing = getWorksharing();
+    for (DirectiveInfo &dir : leafs) {
+      auto found = llvm::find(worksharing, dir.id);
+      if (found != std::end(worksharing))
+        return &dir;
+    }
+    return static_cast<DirectiveInfo *>(nullptr);
+  };
+
+  auto hasWorksharing = findWorksharing();
+  if (hasWorksharing != nullptr) {
+    hasWorksharing->clauses.push_back(node);
+    applied = true;
+  }
+
+  // S - To the taskloop construct if it is among the constituent constructs;
+  auto hasTaskloop = findDirective(llvm::omp::OMPD_taskloop);
+  if (hasTaskloop != nullptr) {
+    hasTaskloop->clauses.push_back(node);
+    applied = true;
+  }
+
+  // S - To the parallel construct if it is among the constituent constructs
+  // S   and neither a taskloop construct nor a worksharing construct that
+  // S   accepts the clause is among them;
+  auto hasParallel = findDirective(llvm::omp::OMPD_parallel);
+  if (hasParallel != nullptr) {
+    if (hasTaskloop == nullptr && hasWorksharing == nullptr) {
+      hasParallel->clauses.push_back(node);
+      applied = true;
+    } else {
+      // S If the parallel construct is among the constituent constructs and
+      // S the effect is not as if the firstprivate clause is applied to it by
+      // S the above rules, then the effect is as if the shared clause with
+      // S the same list item is applied to the parallel construct.
+      auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared,
+                                    omp::clause::Shared{clause.v});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(shared));
+      hasParallel->clauses.push_back(&n);
+    }
+  }
+
+  // S - To the target construct if it is among the constituent constructs
+  // S   and the same list item neither appears in a lastprivate clause nor
+  // S   is the base variable or base pointer of a list item that appears in
+  // S   a map clause.
+  auto inLastprivate = [&](const omp::Object &object) {
+    if (ClauseSet *set = findClauses(object)) {
+      return llvm::find_if(*set, [](const omp::Clause *c) {
+               return c->id == llvm::omp::Clause::OMPC_lastprivate;
+             }) != set->end();
+    }
+    return false;
+  };
+
+  auto hasTarget = findDirective(llvm::omp::OMPD_target);
+  if (hasTarget != nullptr) {
+    omp::ObjectList objects;
+    llvm::copy_if(
+        clause.v, std::back_inserter(objects), [&](const omp::Object &object) {
+          return !inLastprivate(object) && !mapBases.contains(object.sym);
+        });
+    if (!objects.empty()) {
+      auto firstp = omp::makeClause(llvm::omp::Clause::OMPC_firstprivate,
+                                    omp::clause::Firstprivate{objects});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(firstp));
+      hasTarget->clauses.push_back(&n);
+      applied = true;
+    }
+  }
+
+  return applied;
+}
+
+// LASTPRIVATE
+bool CompositeInfo::applyClause(const omp::clause::Lastprivate &clause,
+                                const omp::Clause *node) {
+  bool applied = false;
+
+  // S The effect of the lastprivate clause is as if it is applied to all leaf
+  // S constructs that permit the clause.
+  if (!applyToAll(node)) {
+    llvm::errs() << "Cannot apply LASTPRIVATE\n";
+    return false;
+  }
+
+  auto inFirstprivate = [&](const omp::Object &object) {
+    if (ClauseSet *set = findClauses(object)) {
+      return llvm::find_if(*set, [](const omp::Clause *c) {
+               return c->id == llvm::omp::Clause::OMPC_firstprivate;
+             }) != set->end();
+    }
+    return false;
+  };
+
+  // Prepare list of objects that could end up in a SHARED clause.
+  omp::ObjectList sharedObjects;
+  llvm::copy_if(
+      clause.v, std::back_inserter(sharedObjects),
+      [&](const omp::Object &object) { return !inFirstprivate(object); });
+
+  if (!sharedObjects.empty()) {
+    // S If the parallel construct is among the constituent constructs and the
+    // S list item is not also specified in the firstprivate clause, then the
+    // S effect of the lastprivate clause is as if the shared clause with the
+    // S same list item is applied to the parallel construct.
+    if (auto hasParallel = findDirective(llvm::omp::OMPD_parallel)) {
+      auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared,
+                                    omp::clause::Shared{sharedObjects});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(shared));
+      hasParallel->clauses.push_back(&n);
+      applied = true;
+    }
+
+    // S If the teams construct is among the constituent constructs and the
+    // S list item is not also specified in the firstprivate clause, then the
+    // S effect of the lastprivate clause is as if the shared clause with the
+    // S same list item is applied to the teams construct.
+    if (auto hasTeams = findDirective(llvm::omp::OMPD_teams)) {
+      auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared,
+                                    omp::clause::Shared{sharedObjects});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(shared));
+      hasTeams->clauses.push_back(&n);
+      applied = true;
+    }
+  }
+
+  // S If the target construct is among the constituent constructs and the
+  // S list item is not the base variable or base pointer of a list item that
+  // S appears in a map clause, the effect of the lastprivate clause is as if
+  // S the same list item appears in a map clause with a map-type of tofrom.
+  if (auto hasTarget = findDirective(llvm::omp::OMPD_target)) {
+    omp::ObjectList tofrom;
+    llvm::copy_if(clause.v, std::back_inserter(tofrom),
+                  [&](const omp::Object &object) {
+                    return !mapBases.contains(object.sym);
+                  });
+
+    if (!tofrom.empty()) {
+      auto mapType = omp::clause::Map::MapType{
+          {std::nullopt, omp::clause::Map::MapType::Type::Tofrom}};
+      auto map =
+          omp::makeClause(llvm::omp::Clause::OMPC_map,
+                          omp::clause::Map{{mapType, std::move(tofrom)}});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(map));
+      hasTarget->clauses.push_back(&n);
+      applied = true;
+    }
+  }
+
+  return applied;
+}
+
+// SHARED
+bool CompositeInfo::applyClause(const omp::clause::Shared &clause,
+                                const omp::Clause *node) {
+  // Apply SHARED to all leafs that allow it.
+  if (applyToAll(node))
+    return true;
+  llvm::errs() << "Cannot apply SHARED\n";
+  return false;
+}
+
+// DEFAULT
+bool CompositeInfo::applyClause(const omp::clause::Default &clause,
+                                const omp::Clause *node) {
+  // Apply DEFAULT to all leafs that allow it.
+  if (applyToAll(node))
+    return true;
+  llvm::errs() << "Cannot apply DEFAULT\n";
+  return false;
+}
+
+// THREAD_LIMIT
+bool CompositeInfo::applyClause(const omp::clause::ThreadLimit &clause,
+                                const omp::Clause *node) {
+  // Apply THREAD_LIMIT to all leafs that allow it.
+  if (applyToAll(node))
+    return true;
+  llvm::errs() << "Cannot apply THREAD_LIMIT\n";
+  return false;
+}
+
+// ORDER
+bool CompositeInfo::applyClause(const omp::clause::Order &clause,
+                                const omp::Clause *node) {
+  // Apply ORDER to all leafs that allow it.
+  if (applyToAll(node))
+    return true;
+  llvm::errs() << "Cannot apply ORDER\n";
+  return false;
+}
+
+// ALLOCATE
+bool CompositeInfo::applyClause(const omp::clause::Allocate &clause,
+                                const omp::Clause *node) {
+  // ALLOCATE has to be handled at the very end: the decision below depends
+  // on which clauses have already been assigned to which leaf constructs.
+
+  // S The effect of the allocate clause is as if it is applied to all leaf
+  // S constructs that permit the clause and to which a data-sharing attribute
+  // S clause that may create a private copy of the same list item is applied.
+
+  // Clause kinds treated here as able to create a private copy.
+  auto makesPrivateCopy = [](const omp::Clause *assigned) {
+    const llvm::omp::Clause kind = assigned->id;
+    return kind == llvm::omp::Clause::OMPC_firstprivate ||
+           kind == llvm::omp::Clause::OMPC_lastprivate ||
+           kind == llvm::omp::Clause::OMPC_private;
+  };
+
+  // A leaf qualifies when any clause already attached to it is one of the
+  // kinds above. NOTE(review): only the clause kind is inspected, the list
+  // items themselves are not compared.
+  return applyIf(node, [&](const DirectiveInfo &leaf) {
+    return llvm::any_of(leaf.clauses, makesPrivateCopy);
+  });
+}
+
+// REDUCTION
+bool CompositeInfo::applyClause(const omp::clause::Reduction &clause,
+                                const omp::Clause *node) {
+  // Distributes a REDUCTION clause over the leaf constructs:
+  //  1. decide whether the parallel/teams leafs are excluded,
+  //  2. give excluded parallel/teams leafs an implicit SHARED clause instead,
+  //  3. apply the clause itself to the remaining leafs via applyIf(),
+  //  4. on a combined target construct, add an implicit MAP(tofrom) for list
+  //     items whose base is not already recorded in mapBases.
+  // Returns the result of applyIf(), or true when an implicit MAP clause was
+  // attached to the target leaf.
+
+  // S The effect of the reduction clause is as if it is applied to all leaf
+  // S constructs that permit the clause, except for the following constructs:
+  // S - The parallel construct, when combined with the sections, worksharing-
+  // S   loop, loop, or taskloop construct; and
+  // S - The teams construct, when combined with the loop construct.
+  bool applyToParallel = true, applyToTeams = true;
+
+  auto hasParallel = findDirective(llvm::omp::Directive::OMPD_parallel);
+  if (hasParallel) {
+    // Keep the additional exclusions in a named array. An ArrayRef created
+    // directly from a braced list only refers to a temporary that is
+    // destroyed at the end of the full-expression, which would leave
+    // `exclusions` dangling by the time it is traversed below.
+    const llvm::omp::Directive otherExclusions[] = {
+        llvm::omp::Directive::OMPD_loop,
+        llvm::omp::Directive::OMPD_sections,
+        llvm::omp::Directive::OMPD_taskloop,
+    };
+    auto exclusions = llvm::concat<const llvm::omp::Directive>(
+        getWorksharingLoop(),
+        llvm::ArrayRef<llvm::omp::Directive>(otherExclusions));
+    auto present = [&](llvm::omp::Directive id) {
+      return findDirective(id) != nullptr;
+    };
+
+    if (llvm::any_of(exclusions, present))
+      applyToParallel = false;
+  }
+
+  auto hasTeams = findDirective(llvm::omp::Directive::OMPD_teams);
+  if (hasTeams) {
+    // The only exclusion is OMPD_loop.
+    if (findDirective(llvm::omp::Directive::OMPD_loop))
+      applyToTeams = false;
+  }
+
+  auto &objects = std::get<omp::ObjectList>(clause.t);
+
+  // Replace every list item by its base object when getBaseObject() finds
+  // one; otherwise keep the item itself.
+  omp::ObjectList sharedObjects;
+  llvm::transform(objects, std::back_inserter(sharedObjects),
+                  [&](const omp::Object &object) {
+                    auto maybeBase = getBaseObject(object, semaCtx);
+                    return maybeBase ? *maybeBase : object;
+                  });
+
+  // S For the parallel and teams constructs above, the effect of the
+  // S reduction clause instead is as if each list item or, for any list
+  // S item that is an array item, its corresponding base array or base
+  // S pointer appears in a shared clause for the construct.
+  if (!sharedObjects.empty()) {
+    if (hasParallel && !applyToParallel) {
+      auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared,
+                                    omp::clause::Shared{sharedObjects});
+      // The implicit clause is stored in `extras`; the leaf only keeps a
+      // pointer to it.
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(shared));
+      hasParallel->clauses.push_back(&n);
+    }
+    if (hasTeams && !applyToTeams) {
+      auto shared = omp::makeClause(llvm::omp::Clause::OMPC_shared,
+                                    omp::clause::Shared{sharedObjects});
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(shared));
+      hasTeams->clauses.push_back(&n);
+    }
+  }
+
+  // TODO(not implemented in parser yet): Apply the following.
+  // S If the task reduction-modifier is specified, the effect is as if
+  // S it only modifies the behavior of the reduction clause on the innermost
+  // S leaf construct that accepts the modifier (see Section 5.5.8). If the
+  // S inscan reduction-modifier is specified, the effect is as if it modifies
+  // S the behavior of the reduction clause on all constructs of the combined
+  // S construct to which the clause is applied and that accept the modifier.
+
+  // Apply the clause itself, skipping the excluded parallel/teams leafs.
+  bool applied = applyIf(node, [&](DirectiveInfo &dir) {
+    if (!applyToParallel && &dir == hasParallel)
+      return false;
+    if (!applyToTeams && &dir == hasTeams)
+      return false;
+    return true;
+  });
+
+  // S If a list item in a reduction clause on a combined target construct
+  // S does not have the same base variable or base pointer as a list item
+  // S in a map clause on the construct, then the effect is as if the list
+  // S item in the reduction clause appears as a list item in a map clause
+  // S with a map-type of tofrom.
+  auto hasTarget = findDirective(llvm::omp::Directive::OMPD_target);
+  if (hasTarget && leafs.size() > 1) {
+    omp::ObjectList tofrom;
+    llvm::copy_if(objects, std::back_inserter(tofrom),
+                  [&](const omp::Object &object) {
+                    if (auto maybeBase = getBaseObject(object, semaCtx))
+                      return !mapBases.contains(maybeBase->sym);
+                    return !mapBases.contains(object.sym); // XXX is this ok?
+                  });
+    if (!tofrom.empty()) {
+      auto mapType = omp::clause::Map::MapType{
+          {std::nullopt, omp::clause::Map::MapType::Type::Tofrom}};
+      auto map =
+          omp::makeClause(llvm::omp::Clause::OMPC_map,
+                          omp::clause::Map{{mapType, std::move(tofrom)}});
+
+      const omp::Clause &n = *extras.insert(extras.end(), std::move(map));
+      hasTarget->clauses.push_back(&n);
+      applied = true;
+    }
+  }
+
+  return applied;
+}
+
+// IF
+bool CompositeInfo::applyClause(const omp::clause::If &clause,
+                                const omp::Clause *node) {
+  using DirectiveNameModifier = omp::clause::If::DirectiveNameModifier;
+  const auto &maybeModifier =
+      std::get<std::optional<DirectiveNameModifier>>(clause.t);
+
+  // No directive-name-modifier: the clause is handed to applyToAll().
+  if (!maybeModifier) {
+    if (applyToAll(node))
+      return true;
+
+    llvm::errs() << "Cannot apply IF\n";
+    return false;
+  }
+
+  // Translate the modifier into the leaf directive it names. Modifiers that
+  // are not accepted here map to OMPD_unknown.
+  auto leafForModifier = [](DirectiveNameModifier which) {
+    switch (which) {
+    case DirectiveNameModifier::Parallel:
+      return llvm::omp::Directive::OMPD_parallel;
+    case DirectiveNameModifier::Simd:
+      return llvm::omp::Directive::OMPD_simd;
+    case DirectiveNameModifier::Target:
+      return llvm::omp::Directive::OMPD_target;
+    case DirectiveNameModifier::Task:
+      return llvm::omp::Directive::OMPD_task;
+    case DirectiveNameModifier::Taskloop:
+      return llvm::omp::Directive::OMPD_taskloop;
+    case DirectiveNameModifier::Teams:
+      return llvm::omp::Directive::OMPD_teams;
+
+    case DirectiveNameModifier::TargetData:
+    case DirectiveNameModifier::TargetEnterData:
+    case DirectiveNameModifier::TargetExitData:
+    case DirectiveNameModifier::TargetUpdate:
+    default:
+      return llvm::omp::Directive::OMPD_unknown;
+    }
+  };
+
+  const llvm::omp::Directive dirId = leafForModifier(*maybeModifier);
+  if (dirId == llvm::omp::Directive::OMPD_unknown) {
+    llvm::errs() << "Invalid modifier in IF clause\n";
+    return false;
+  }
+
+  // With a modifier the clause goes only on the named leaf, which must be
+  // one of the leafs known to findDirective().
+  auto *namedLeaf = findDirective(dirId);
+  if (!namedLeaf) {
+    llvm::errs() << "Directive from modifier not found\n";
+    return false;
+  }
+
+  namedLeaf->clauses.push_back(node);
+  return true;
+}
+
+// LINEAR
+bool CompositeInfo::applyClause(const omp::clause::Linear &clause,
+                                const omp::Clause *node) {
+  // Applies LINEAR to the innermost leaf and, when a simd leaf is present,
+  // appends the implied FIRSTPRIVATE/LASTPRIVATE clauses to the main clause
+  // list (`clauses`). Returns false only if the clause could not be placed
+  // on the innermost leaf.
+
+  // S The effect of the linear clause is as if it is applied to the innermost
+  // S leaf construct.
+  // A true result from applyToInnermost() means success, as with the other
+  // applyTo*() helpers; only a false result is an error.
+  if (!applyToInnermost(node)) {
+    llvm::errs() << "Cannot apply LINEAR\n";
+    return false;
+  }
+
+  // The rest is about SIMD.
+  if (!findDirective(llvm::omp::OMPD_simd))
+    return true;
+
+  // S Additionally, if the list item is not the iteration variable of a
+  // S simd or worksharing-loop SIMD construct, the effect on the outer leaf
+  // S constructs is as if the list item was specified in firstprivate and
+  // S lastprivate clauses on the combined or composite construct, [...]
+  //
+  // S If a list item of the linear clause is the iteration variable of a
+  // S simd or worksharing-loop SIMD construct and it is not declared in
+  // S the construct, the effect on the outer leaf constructs is as if the
+  // S list item was specified in a lastprivate clause on the combined or
+  // S composite construct [...]
+
+  // It's not clear how an object can be listed in a clause AND be the
+  // iteration variable of a construct in which it is declared. If an
+  // object is declared in the construct, then the declaration is located
+  // after the clause listing it.
+
+  Fortran::semantics::Symbol *iterVarSym = getIterationVariableSymbol(eval);
+  const auto &objects = std::get<omp::ObjectList>(clause.t);
+
+  // Lists of objects that will be used to construct FIRSTPRIVATE and
+  // LASTPRIVATE clauses.
+  omp::ObjectList first, last;
+
+  // Every list item becomes LASTPRIVATE; all but the iteration variable
+  // additionally become FIRSTPRIVATE.
+  for (const omp::Object &object : objects) {
+    last.push_back(object);
+    if (object.sym != iterVarSym)
+      first.push_back(object);
+  }
+
+  if (!first.empty()) {
+    auto firstp = omp::makeClause(llvm::omp::Clause::OMPC_firstprivate,
+                                  omp::clause::Firstprivate{first});
+    clauses.push_back(std::move(firstp)); // Appending to the main clause list.
+  }
+  if (!last.empty()) {
+    auto lastp = omp::makeClause(llvm::omp::Clause::OMPC_lastprivate,
+                                 omp::clause::Lastprivate{last});
+    clauses.push_back(std::move(lastp)); // Appending to the main clause list.
+  }
+  return true;
+}
+
+// NOWAIT
+bool CompositeInfo::applyClause(const omp::clause::Nowait &clause,
+                                const omp::Clause *node) {
+  // NOWAIT is handed to applyToOutermost(); a false result is reported as
+  // a failure.
+  if (!applyToOutermost(node)) {
+    llvm::errs() << "Cannot apply NOWAIT\n";
+    return false;
+  }
+  return true;
+}
+
+bool CompositeInfo::split() {
+  // Applies every clause in `clauses` in three passes: LINEAR first, then
+  // everything except LINEAR/ALLOCATE, and ALLOCATE last. Processing stops
+  // at the first applyClause() that fails, and false is returned; no later
+  // clause is applied after a failure.
+
+  // Dispatch on the clause alternative stored in the node.
+  auto applyNode = [&](const omp::Clause &node) {
+    return std::visit([&](auto &&s) { return applyClause(s, &node); }, node.u);
+  };
+
+  // LINEAR goes first, because it can generate additional FIRSTPRIVATE and
+  // LASTPRIVATE clauses that apply to the combined/composite construct.
+  // The LINEAR nodes are gathered up front, since applying them may modify
+  // the clause list being walked.
+  llvm::SmallVector<const omp::Clause *> linearNodes;
+  for (const omp::Clause &node : clauses) {
+    if (node.id == llvm::omp::Clause::OMPC_linear)
+      linearNodes.push_back(&node);
+  }
+  for (const omp::Clause *linear : linearNodes) {
+    if (!applyClause(std::get<omp::clause::Linear>(linear->u), linear))
+      return false;
+  }
+
+  // Apply (almost) all clauses: LINEAR has been handled above, and ALLOCATE
+  // must wait until it can see which directives have data-privatizing
+  // clauses.
+  for (const omp::Clause &node : clauses) {
+    const bool handledElsewhere =
+        node.id == llvm::omp::Clause::OMPC_allocate ||
+        node.id == llvm::omp::Clause::OMPC_linear;
+    if (handledElsewhere)
+      continue;
+    if (!applyNode(node))
+      return false;
+  }
+
+  // Apply ALLOCATE.
+  for (const omp::Clause &node : clauses) {
+    if (node.id != llvm::omp::Clause::OMPC_allocate)
+      continue;
+    if (!applyNode(node))
+      return false;
+  }
+
+  return true;
+}
+#endif
+
+// Builds a CompositeInfo for the composite directive `compDir` with the
+// clauses in `clauseList` and runs CompositeInfo::split() on it. The result
+// is currently only of interest to the disabled debug dumps below.
+static void splitCompositeConstruct(
+    const mlir::ModuleOp &modOp, Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive compDir,
+    const Fortran::parser::OmpClauseList &clauseList) {
+  // Debug aid (disabled): dump the composite directive and its clauses.
+  //  llvm::errs() << "composite name:"
+  //               << llvm::omp::getOpenMPDirectiveName(compDir) << '\n';
+  //  llvm::errs() << "clause list:";
+  //  for (auto &clause : clauseList.v) {
+  //    std::visit([&](auto &&s) { omp::clause::make(s, semaCtx); },
+  //    clause.u); llvm::errs() << ' ' <<
+  //    llvm::omp::getOpenMPClauseName(getClauseId(clause));
+  //  }
+  //  llvm::errs() << '\n';
+
+  CompositeInfo compInfo(modOp, semaCtx, eval, compDir, clauseList);
+  //  llvm::errs() << "compInfo.1\n" << compInfo << '\n';
+
+  // Only read by the disabled dump below, hence [[maybe_unused]].
+  [[maybe_unused]] bool success = compInfo.split();
+
+  // Dump
+  //  llvm::errs() << "success:" << success << '\n';
+  //  llvm::errs() << "compInfo.2\n" << compInfo << '\n';
+}
+
 //===----------------------------------------------------------------------===//
 // DataSharingProcessor
 //===----------------------------------------------------------------------===//
@@ -4548,6 +5568,10 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
                    const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
   const auto &beginLoopDirective =
       std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
+  // Test call
+  splitCompositeConstruct(converter.getFirOpBuilder().getModule(), semaCtx,
+                          eval, std::get<0>(beginLoopDirective.t).v,
+                          std::get<1>(beginLoopDirective.t));
   const auto &loopOpClauseList =
       std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
   mlir::Location currentLocation =
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index c9358c83e795c4..bdae1731260de1 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -359,7 +359,6 @@ static mlir::LogicalResult convertFortranSourceToMLIR(
       semanticsContext.targetCharacteristics(), parsing.allCooked(),
       targetTriple, kindMap, loweringOptions, {},
       semanticsContext.languageFeatures(), targetMachine);
-  burnside.lower(parseTree, semanticsContext);
   mlir::ModuleOp mlirModule = burnside.getModule();
   if (enableOpenMP) {
     if (enableOpenMPGPU && !enableOpenMPDevice) {
@@ -375,6 +374,7 @@ static mlir::LogicalResult convertFortranSourceToMLIR(
     setOffloadModuleInterfaceAttributes(mlirModule, offloadModuleOpts);
     setOpenMPVersionAttribute(mlirModule, setOpenMPVersion);
   }
+  burnside.lower(parseTree, semanticsContext);
   std::error_code ec;
   std::string outputName = outputFilename;
   if (!outputName.size())
diff --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
index 31578710365b21..24eb54e75c96ba 100644
--- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
+++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
@@ -152,6 +152,10 @@ class Directive<string d> {
   // List of clauses that are required.
   list<VersionedClause> requiredClauses = [];
 
+  // List of leaf constituent directives in the order in which they appear
+  // in the combined/composite directive.
+  list<Directive> leafs = [];
+
   // Set directive used by default when unknown.
   bit isDefault = false;
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 1481328bf483b8..534ab58985b57d 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -773,6 +773,7 @@ def OMP_TargetParallel : Directive<"target parallel"> {
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_ThreadLimit, 51>,
   ];
+  let leafs = [OMP_Target, OMP_Parallel];
 }
 def OMP_TargetParallelFor : Directive<"target parallel for"> {
   let allowedClauses = [
@@ -805,6 +806,7 @@ def OMP_TargetParallelFor : Directive<"target parallel for"> {
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_ThreadLimit, 51>,
   ];
+  let leafs = [OMP_Target, OMP_Parallel, OMP_For];
 }
 def OMP_TargetParallelDo : Directive<"target parallel do"> {
   let allowedClauses = [
@@ -835,6 +837,7 @@ def OMP_TargetParallelDo : Directive<"target parallel do"> {
     VersionedClause<OMPC_NoWait>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Target, OMP_Parallel, OMP_Do];
 }
 def OMP_TargetUpdate : Directive<"target update"> {
   let allowedClauses = [
@@ -848,6 +851,11 @@ def OMP_TargetUpdate : Directive<"target update"> {
     VersionedClause<OMPC_NoWait>
   ];
 }
+// Defined ahead of the combined directives (e.g. "parallel masked") that
+// name it in their `leafs` list.
+def OMP_masked : Directive<"masked"> {
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Filter>
+  ];
+}
 def OMP_ParallelFor : Directive<"parallel for"> {
   let allowedClauses = [
     VersionedClause<OMPC_If>,
@@ -868,6 +876,7 @@ def OMP_ParallelFor : Directive<"parallel for"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_For];
 }
 def OMP_ParallelDo : Directive<"parallel do"> {
   let allowedClauses = [
@@ -889,6 +898,7 @@ def OMP_ParallelDo : Directive<"parallel do"> {
     VersionedClause<OMPC_Collapse>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Parallel, OMP_Do];
 }
 def OMP_ParallelForSimd : Directive<"parallel for simd"> {
   let allowedClauses = [
@@ -914,6 +924,7 @@ def OMP_ParallelForSimd : Directive<"parallel for simd"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_For, OMP_Simd];
 }
 def OMP_ParallelDoSimd : Directive<"parallel do simd"> {
   let allowedClauses = [
@@ -940,6 +951,7 @@ def OMP_ParallelDoSimd : Directive<"parallel do simd"> {
     VersionedClause<OMPC_SimdLen>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Parallel, OMP_Do, OMP_Simd];
 }
 def OMP_ParallelMaster : Directive<"parallel master"> {
   let allowedClauses = [
@@ -955,6 +967,7 @@ def OMP_ParallelMaster : Directive<"parallel master"> {
     VersionedClause<OMPC_Allocate>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_Master];
 }
 def OMP_ParallelMasked : Directive<"parallel masked"> {
   let allowedClauses = [
@@ -971,6 +984,7 @@ def OMP_ParallelMasked : Directive<"parallel masked"> {
     VersionedClause<OMPC_Filter>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_masked];
 }
 def OMP_ParallelSections : Directive<"parallel sections"> {
   let allowedClauses = [
@@ -989,6 +1003,7 @@ def OMP_ParallelSections : Directive<"parallel sections"> {
     VersionedClause<OMPC_If>,
     VersionedClause<OMPC_NumThreads>
   ];
+  let leafs = [OMP_Parallel, OMP_Sections];
 }
 def OMP_ForSimd : Directive<"for simd"> {
   let allowedClauses = [
@@ -1009,6 +1024,7 @@ def OMP_ForSimd : Directive<"for simd"> {
     VersionedClause<OMPC_NonTemporal, 50>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_For, OMP_Simd];
 }
 def OMP_DoSimd : Directive<"do simd"> {
   let allowedClauses = [
@@ -1029,6 +1045,7 @@ def OMP_DoSimd : Directive<"do simd"> {
     VersionedClause<OMPC_NoWait>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Do, OMP_Simd];
 }
 def OMP_CancellationPoint : Directive<"cancellation point"> {}
 def OMP_DeclareReduction : Directive<"declare reduction"> {}
@@ -1106,6 +1123,7 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
     VersionedClause<OMPC_GrainSize>,
     VersionedClause<OMPC_NumTasks>
   ];
+  let leafs = [OMP_TaskLoop, OMP_Simd];
 }
 def OMP_Distribute : Directive<"distribute"> {
   let allowedClauses = [
@@ -1158,6 +1176,7 @@ def OMP_DistributeParallelFor : Directive<"distribute parallel for"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Distribute, OMP_Parallel, OMP_For];
 }
 def OMP_DistributeParallelDo : Directive<"distribute parallel do"> {
   let allowedClauses = [
@@ -1181,6 +1200,7 @@ def OMP_DistributeParallelDo : Directive<"distribute parallel do"> {
     VersionedClause<OMPC_Ordered>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Distribute, OMP_Parallel, OMP_Do];
 }
 def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> {
   let allowedClauses = [
@@ -1206,6 +1226,7 @@ def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Distribute, OMP_Parallel, OMP_For, OMP_Simd];
 }
 def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> {
   let allowedClauses = [
@@ -1230,6 +1251,7 @@ def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> {
     VersionedClause<OMPC_NonTemporal>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Distribute, OMP_Parallel, OMP_Do, OMP_Simd];
 }
 def OMP_DistributeSimd : Directive<"distribute simd"> {
   let allowedClauses = [
@@ -1256,6 +1278,7 @@ def OMP_DistributeSimd : Directive<"distribute simd"> {
     VersionedClause<OMPC_SimdLen>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Distribute, OMP_Simd];
 }
 
 def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> {
@@ -1293,6 +1316,7 @@ def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> {
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_ThreadLimit, 51>,
   ];
+  let leafs = [OMP_Target, OMP_Parallel, OMP_For, OMP_Simd];
 }
 def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> {
   let allowedClauses = [
@@ -1324,6 +1348,7 @@ def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_UsesAllocators>
   ];
+  let leafs = [OMP_Target, OMP_Parallel, OMP_Do, OMP_Simd];
 }
 def OMP_TargetSimd : Directive<"target simd"> {
   let allowedClauses = [
@@ -1358,6 +1383,7 @@ def OMP_TargetSimd : Directive<"target simd"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_ThreadLimit, 51>,
   ];
+  let leafs = [OMP_Target, OMP_Simd];
 }
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [
@@ -1377,6 +1403,7 @@ def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedOnceClauses = [
     VersionedClause<OMPC_If>
   ];
+  let leafs = [OMP_Teams, OMP_Distribute];
 }
 def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> {
   let allowedClauses = [
@@ -1402,6 +1429,7 @@ def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> {
     VersionedClause<OMPC_ThreadLimit>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Teams, OMP_Distribute, OMP_Simd];
 }
 
 def OMP_TeamsDistributeParallelForSimd :
@@ -1430,6 +1458,7 @@ def OMP_TeamsDistributeParallelForSimd :
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For, OMP_Simd];
 }
 def OMP_TeamsDistributeParallelDoSimd :
     Directive<"teams distribute parallel do simd"> {
@@ -1458,6 +1487,7 @@ def OMP_TeamsDistributeParallelDoSimd :
     VersionedClause<OMPC_SimdLen>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do, OMP_Simd];
 }
 def OMP_TeamsDistributeParallelFor :
     Directive<"teams distribute parallel for"> {
@@ -1481,6 +1511,7 @@ def OMP_TeamsDistributeParallelFor :
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For];
 }
 def OMP_TeamsDistributeParallelDo :
     Directive<"teams distribute parallel do"> {
@@ -1507,6 +1538,7 @@ let allowedOnceClauses = [
     VersionedClause<OMPC_Schedule>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do];
 }
 def OMP_TargetTeams : Directive<"target teams"> {
   let allowedClauses = [
@@ -1534,6 +1566,7 @@ def OMP_TargetTeams : Directive<"target teams"> {
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_OMX_Bare>,
   ];
+  let leafs = [OMP_Target, OMP_Teams];
 }
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
@@ -1562,6 +1595,7 @@ def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
     VersionedClause<OMPC_DistSchedule>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_Distribute];
 }
 
 def OMP_TargetTeamsDistributeParallelFor :
@@ -1596,6 +1630,7 @@ def OMP_TargetTeamsDistributeParallelFor :
   let allowedOnceClauses = [
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For];
 }
 def OMP_TargetTeamsDistributeParallelDo :
     Directive<"target teams distribute parallel do"> {
@@ -1630,6 +1665,7 @@ def OMP_TargetTeamsDistributeParallelDo :
     VersionedClause<OMPC_Schedule>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do];
 }
 def OMP_TargetTeamsDistributeParallelForSimd :
     Directive<"target teams distribute parallel for simd"> {
@@ -1668,6 +1704,7 @@ def OMP_TargetTeamsDistributeParallelForSimd :
   let allowedOnceClauses = [
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_For, OMP_Simd];
 }
 def OMP_TargetTeamsDistributeParallelDoSimd :
     Directive<"target teams distribute parallel do simd"> {
@@ -1706,6 +1743,7 @@ def OMP_TargetTeamsDistributeParallelDoSimd :
     VersionedClause<OMPC_SimdLen>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Parallel, OMP_Do, OMP_Simd];
 }
 def OMP_TargetTeamsDistributeSimd :
     Directive<"target teams distribute simd"> {
@@ -1740,6 +1778,7 @@ def OMP_TargetTeamsDistributeSimd :
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
 }
 def OMP_Allocate : Directive<"allocate"> {
   let allowedOnceClauses = [
@@ -1781,6 +1820,7 @@ def OMP_MasterTaskloop : Directive<"master taskloop"> {
     VersionedClause<OMPC_InReduction>,
     VersionedClause<OMPC_Allocate>
   ];
+  let leafs = [OMP_Master, OMP_TaskLoop];
 }
 def OMP_MaskedTaskloop : Directive<"masked taskloop"> {
   let allowedClauses = [
@@ -1803,6 +1843,7 @@ def OMP_MaskedTaskloop : Directive<"masked taskloop"> {
     VersionedClause<OMPC_Allocate>,
     VersionedClause<OMPC_Filter>
   ];
+  let leafs = [OMP_masked, OMP_TaskLoop];
 }
 def OMP_ParallelMasterTaskloop :
     Directive<"parallel master taskloop"> {
@@ -1828,6 +1869,7 @@ def OMP_ParallelMasterTaskloop :
     VersionedClause<OMPC_Copyin>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_Master, OMP_TaskLoop];
 }
 def OMP_ParallelMaskedTaskloop :
     Directive<"parallel masked taskloop"> {
@@ -1854,6 +1896,7 @@ def OMP_ParallelMaskedTaskloop :
     VersionedClause<OMPC_Filter>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_masked, OMP_TaskLoop];
 }
 def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> {
   let allowedClauses = [
@@ -1881,6 +1924,7 @@ def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> {
     VersionedClause<OMPC_NonTemporal, 50>,
     VersionedClause<OMPC_Order, 50>
   ];
+  let leafs = [OMP_Master, OMP_TaskLoop, OMP_Simd];
 }
 def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> {
   let allowedClauses = [
@@ -1909,6 +1953,7 @@ def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> {
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_Filter>
   ];
+  let leafs = [OMP_masked, OMP_TaskLoop, OMP_Simd];
 }
 def OMP_ParallelMasterTaskloopSimd :
     Directive<"parallel master taskloop simd"> {
@@ -1940,6 +1985,7 @@ def OMP_ParallelMasterTaskloopSimd :
     VersionedClause<OMPC_Order, 50>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_Master, OMP_TaskLoop, OMP_Simd];
 }
 def OMP_ParallelMaskedTaskloopSimd :
     Directive<"parallel masked taskloop simd"> {
@@ -1972,6 +2018,7 @@ def OMP_ParallelMaskedTaskloopSimd :
     VersionedClause<OMPC_Filter>,
     VersionedClause<OMPC_OMPX_Attribute>,
   ];
+  let leafs = [OMP_Parallel, OMP_masked, OMP_TaskLoop, OMP_Simd];
 }
 def OMP_Depobj : Directive<"depobj"> {
   let allowedClauses = [
@@ -2003,6 +2050,7 @@ def OMP_scope : Directive<"scope"> {
     VersionedClause<OMPC_NoWait, 51>
   ];
 }
+// Defined ahead of "parallel workshare", which names it in its `leafs` list.
+def OMP_Workshare : Directive<"workshare"> {}
 def OMP_ParallelWorkshare : Directive<"parallel workshare"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2018,8 +2066,8 @@ def OMP_ParallelWorkshare : Directive<"parallel workshare"> {
     VersionedClause<OMPC_NumThreads>,
     VersionedClause<OMPC_ProcBind>
   ];
+  let leafs = [OMP_Parallel, OMP_Workshare];
 }
-def OMP_Workshare : Directive<"workshare"> {}
 def OMP_EndDo : Directive<"end do"> {
   let allowedOnceClauses = [
     VersionedClause<OMPC_NoWait>
@@ -2069,11 +2117,6 @@ def OMP_dispatch : Directive<"dispatch"> {
     VersionedClause<OMPC_Nocontext>
   ];
 }
-def OMP_masked : Directive<"masked"> {
-  let allowedOnceClauses = [
-    VersionedClause<OMPC_Filter>
-  ];
-}
 def OMP_loop : Directive<"loop"> {
   let allowedClauses = [
     VersionedClause<OMPC_LastPrivate>,
@@ -2104,6 +2147,7 @@ def OMP_teams_loop : Directive<"teams loop"> {
     VersionedClause<OMPC_Order>,
     VersionedClause<OMPC_ThreadLimit>,
   ];
+  let leafs = [OMP_Teams, OMP_loop];
 }
 def OMP_target_teams_loop : Directive<"target teams loop"> {
   let allowedClauses = [
@@ -2133,6 +2177,7 @@ def OMP_target_teams_loop : Directive<"target teams loop"> {
     VersionedClause<OMPC_ThreadLimit>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
   ];
+  let leafs = [OMP_Target, OMP_Teams, OMP_loop];
 }
 def OMP_parallel_loop : Directive<"parallel loop"> {
   let allowedClauses = [
@@ -2154,6 +2199,7 @@ def OMP_parallel_loop : Directive<"parallel loop"> {
     VersionedClause<OMPC_Order>,
     VersionedClause<OMPC_ProcBind>,
   ];
+  let leafs = [OMP_Parallel, OMP_loop];
 }
 def OMP_target_parallel_loop : Directive<"target parallel loop"> {
   let allowedClauses = [
@@ -2185,11 +2231,13 @@ def OMP_target_parallel_loop : Directive<"target parallel loop"> {
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_ThreadLimit, 51>,
   ];
+  let leafs = [OMP_Target, OMP_Parallel, OMP_loop];
 }
 def OMP_Metadirective : Directive<"metadirective"> {
   let allowedClauses = [VersionedClause<OMPC_When>];
   let allowedOnceClauses = [VersionedClause<OMPC_Default>];
 }
+
 def OMP_Unknown : Directive<"unknown"> {
   let isDefault = true;
 }
diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h
index c86018715a48a1..f655e584f891e1 100644
--- a/llvm/include/llvm/TableGen/DirectiveEmitter.h
+++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h
@@ -121,6 +121,10 @@ class Directive : public BaseRecord {
   std::vector<Record *> getRequiredClauses() const {
     return Def->getValueAsListOfDefs("requiredClauses");
   }
+
+  // Returns the directive records listed in this directive's `leafs` field
+  // (see DirectiveBase.td).
+  std::vector<Record *> getLeafConstructs() const {
+    return Def->getValueAsListOfDefs("leafs");
+  }
 };
 
 // Wrapper class that contains Clause's information defined in DirectiveBase.td
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index b6aee665f8ee0b..7cb2a5cbe95954 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -186,6 +186,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
 
   if (DirLang.hasEnableBitmaskEnumInNamespace())
     OS << "\n#include \"llvm/ADT/BitmaskEnum.h\"\n";
+  OS << "#include \"llvm/ADT/SmallVector.h\"\n";
 
   OS << "\n";
   OS << "namespace llvm {\n";
@@ -231,6 +232,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   OS << "bool isAllowedClauseForDirective(Directive D, "
      << "Clause C, unsigned Version);\n";
   OS << "\n";
+  OS << "const llvm::SmallVector<Directive> &getLeafConstructs(Directive D);\n";
   if (EnumHelperFuncs.length() > 0) {
     OS << EnumHelperFuncs;
     OS << "\n";
@@ -435,6 +437,78 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
   OS << "}\n"; // End of function isAllowedClauseForDirective
 }
 
+// Emit to OS the definition of llvm::<ns>::getLeafConstructs(Directive Dir).
+static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
+                                      raw_ostream &OS) {
+  auto getQualifiedName = [&](StringRef Formatted) -> std::string {
+    return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
+            "::Directive::" + DirLang.getDirectivePrefix() + Formatted)
+        .str();
+  }; // Yields "llvm::<ns>::Directive::<prefix><Formatted>".
+
+  // For each directive with a non-empty "leafs" list, emit a static local
+  // list object; then emit a switch returning a reference to the matching
+  // object, or to an empty list ("nothing") for every other directive, e.g.
+  //   static ListTy leafConstructs_A_B { A, B };
+  //   static ListTy leafConstructs_C_D_E { C, D, E };
+  //   static ListTy nothing {};
+  //   switch (Dir) {
+  //     case A_B:   return leafConstructs_A_B;
+  //     case C_D_E: return leafConstructs_C_D_E;
+  //     default:    return nothing;
+
+  // Map from a directive's defining record to the name of the emitted local
+  // list object; only directives with a non-empty leaf list get an entry.
+  DenseMap<Record *, std::string> ListNames;
+
+  std::string DirectiveTypeName =
+      std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
+  std::string DirectiveListTypeName =
+      std::string("llvm::SmallVector<") + DirectiveTypeName + ">";
+
+  // const ListTy &llvm::<ns>::getLeafConstructs(llvm::<ns>::Directive Dir) {
+  OS << "const " << DirectiveListTypeName
+     << " &llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
+     << DirectiveTypeName << " Dir) ";
+  OS << "{\n";
+
+  // Emit the static locals: one list per directive that declares leafs.
+  for (Record *R : DirLang.getDirectives()) {
+    Directive Dir{R};
+
+    std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
+    if (LeafConstructs.empty())
+      continue; // No leafs: no local and no map entry for this directive.
+
+    std::string ListName = "leafConstructs_" + Dir.getFormattedName();
+    OS << "  static " << DirectiveListTypeName << ' ' << ListName << " {\n";
+    for (Record *L : LeafConstructs) {
+      Directive LeafDir{L};
+      OS << "    " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
+    }
+    OS << "  };\n";
+    ListNames.insert(std::make_pair(R, std::move(ListName)));
+  }
+
+  OS << "  static " << DirectiveListTypeName << " nothing {};\n";
+
+  OS << '\n';
+  OS << "  switch (Dir) {\n";
+  for (Record *R : DirLang.getDirectives()) {
+    auto F = ListNames.find(R);
+    if (F == ListNames.end())
+      continue; // Not in the map: covered by the emitted default case.
+
+    Directive Dir{R};
+    OS << "  case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
+    OS << "    return " << F->second << ";\n";
+  }
+  OS << "  default:\n";
+  OS << "    return nothing;\n";
+  OS << "  } // switch (Dir)\n";
+  OS << "}\n";
+}
+
 // Generate a simple enum set with the give clauses.
 static void GenerateClauseSet(const std::vector<Record *> &Clauses,
                               raw_ostream &OS, StringRef ClauseSetPrefix,
@@ -876,6 +950,9 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
 
   // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
   GenerateIsAllowedClause(DirLang, OS);
+
+  // getLeafConstructs(Directive D)
+  GenerateGetLeafConstructs(DirLang, OS);
 }
 
 // Generate the implemenation section for the enumeration in the directive



More information about the llvm-branch-commits mailing list