[llvm-branch-commits] [llvm] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct (PR #87247)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Apr 1 08:23:11 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

Emit a special leaf construct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leafs, and composition of constituent constructs into a single compound construct (if possible).

---
Full diff: https://github.com/llvm/llvm-project/pull/87247.diff


5 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.h (+7) 
- (modified) llvm/lib/Frontend/OpenMP/OMP.cpp (+63-1) 
- (modified) llvm/unittests/Frontend/CMakeLists.txt (+1) 
- (added) llvm/unittests/Frontend/OpenMPComposeTest.cpp (+40) 
- (modified) llvm/utils/TableGen/DirectiveEmitter.cpp (+124-70) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index a85cd9d344c6d7..4ed47f15dfe59e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -15,4 +15,11 @@
 
 #include "llvm/Frontend/OpenMP/OMP.h.inc"
 
+#include "llvm/ADT/ArrayRef.h"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D);
+Directive getCompoundConstruct(ArrayRef<Directive> Parts);
+} // namespace llvm::omp
+
 #endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4f2f95392648b3..dd99d3d074fd1e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -8,12 +8,74 @@
 
 #include "llvm/Frontend/OpenMP/OMP.h"
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/ErrorHandling.h"
 
+#include <algorithm>
+#include <iterator>
+#include <type_traits>
+
 using namespace llvm;
-using namespace omp;
+using namespace llvm::omp;
 
 #define GEN_DIRECTIVES_IMPL
 #include "llvm/Frontend/OpenMP/OMP.inc"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D) {
+  auto Idx = static_cast<int>(D);
+  if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+    return {};
+  const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
+  return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
+}
+
+Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
+  if (Parts.empty())
+    return OMPD_unknown;
+
+  // Parts don't have to be leafs, so expand them into leafs first.
+  // Store the expanded leafs in the same format as rows in the leaf
+  // table (generated by tablegen).
+  SmallVector<Directive> RawLeafs(2);
+  for (Directive P : Parts) {
+    ArrayRef<Directive> Ls = getLeafConstructs(P);
+    if (!Ls.empty())
+      RawLeafs.append(Ls.begin(), Ls.end());
+    else
+      RawLeafs.push_back(P);
+  }
+
+  auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
+  if (GivenLeafs.size() == 1)
+    return GivenLeafs.front();
+  RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
+
+  auto Iter = llvm::lower_bound(
+      LeafConstructTable,
+      static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
+      [](const auto *RowA, const auto *RowB) {
+        const auto *BeginA = &RowA[2];
+        const auto *EndA = BeginA + static_cast<int>(RowA[1]);
+        const auto *BeginB = &RowB[2];
+        const auto *EndB = BeginB + static_cast<int>(RowB[1]);
+        if (BeginA == EndA && BeginB == EndB)
+          return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
+        return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
+      });
+
+  if (Iter == std::end(LeafConstructTable))
+    return OMPD_unknown;
+
+  // Verify that we got a match.
+  Directive Found = (*Iter)[0];
+  ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
+  if (FoundLeafs == GivenLeafs)
+    return Found;
+  return OMPD_unknown;
+}
+} // namespace llvm::omp
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index c6f60142d6276a..ddb6a16cbb984e 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
   OpenMPContextTest.cpp
   OpenMPIRBuilderTest.cpp
   OpenMPParsingTest.cpp
+  OpenMPComposeTest.cpp
 
   DEPENDS
   acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
new file mode 100644
index 00000000000000..2dc35aca8842e9
--- /dev/null
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -0,0 +1,40 @@
+//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::omp;
+
+TEST(Composition, GetLeafConstructs) {
+  ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
+  ASSERT_EQ(L1, (ArrayRef<Directive>{}));
+  ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
+  ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
+  ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
+  ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
+}
+
+TEST(Composition, GetCompoundConstruct) {
+  Directive C1 = getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
+  ASSERT_EQ(C1, OMPD_target_teams_distribute);
+  Directive C2 = getCompoundConstruct({OMPD_target});
+  ASSERT_EQ(C2, OMPD_target);
+  Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
+  ASSERT_EQ(C3, OMPD_unknown);
+  Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+  ASSERT_EQ(C4, OMPD_target_teams_distribute);
+  Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+  ASSERT_EQ(C5, OMPD_target_teams_distribute);
+  Directive C6 = getCompoundConstruct({});
+  ASSERT_EQ(C6, OMPD_unknown);
+  Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+  ASSERT_EQ(C7, OMPD_parallel_for_simd);
+}
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index e0edf1720f8ac5..34b517e816a243 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -20,6 +20,9 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#include <numeric>
+#include <vector>
+
 using namespace llvm;
 
 namespace {
@@ -39,7 +42,8 @@ class IfDefScope {
 };
 } // namespace
 
-// Generate enum class
+// Generate enum class. Entries are emitted in the order in which they appear
+// in the `Records` vector.
 static void GenerateEnumClass(const std::vector<Record *> &Records,
                               raw_ostream &OS, StringRef Enum, StringRef Prefix,
                               const DirectiveLanguage &DirLang,
@@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const {
   return HasDuplicateClausesInDirectives(getDirectives());
 }
 
+// Count the maximum number of leaf constituents per construct.
+static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) {
+  size_t MaxCount = 0;
+  for (Record *R : DirLang.getDirectives()) {
+    size_t Count = Directive{R}.getLeafConstructs().size();
+    MaxCount = std::max(MaxCount, Count);
+  }
+  return MaxCount;
+}
+
 // Generate the declaration section for the enumeration in the directive
 // language
 static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   if (DirLang.hasEnableBitmaskEnumInNamespace())
     OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n";
 
+  OS << "#include <cstddef>\n"; // for size_t
   OS << "\n";
   OS << "namespace llvm {\n";
   OS << "class StringRef;\n";
@@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   OS << "bool isAllowedClauseForDirective(Directive D, "
      << "Clause C, unsigned Version);\n";
   OS << "\n";
-  OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n";
+  OS << "constexpr std::size_t getMaxLeafCount() { return "
+     << GetMaxLeafCount(DirLang) << "; }\n";
   OS << "Association getDirectiveAssociation(Directive D);\n";
   if (EnumHelperFuncs.length() > 0) {
     OS << EnumHelperFuncs;
@@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
   }
 }
 
+static std::string GetDirectiveName(const DirectiveLanguage &DirLang,
+                                    const Record *Rec) {
+  Directive Dir{Rec};
+  return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" +
+          DirLang.getDirectivePrefix() + Dir.getFormattedName())
+      .str();
+}
+
+static std::string GetDirectiveType(const DirectiveLanguage &DirLang) {
+  return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive")
+      .str();
+}
+
 // Generate the isAllowedClauseForDirective function implementation.
 static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
                                     raw_ostream &OS) {
@@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
   OS << "}\n"; // End of function isAllowedClauseForDirective
 }
 
-// Generate the getLeafConstructs function implementation.
-static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
-                                      raw_ostream &OS) {
-  auto getQualifiedName = [&](StringRef Formatted) -> std::string {
-    return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
-            "::Directive::" + DirLang.getDirectivePrefix() + Formatted)
-        .str();
-  };
-
-  // For each list of leaves, generate a static local object, then
-  // return a reference to that object for a given directive, e.g.
+static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS,
+                          StringRef TableName) {
+  // The leaf constructs are emitted in a form of a 2D table, where each
+  // row corresponds to a directive (and there is a row for each directive).
   //
-  //   static ListTy leafConstructs_A_B = { A, B };
-  //   static ListTy leafConstructs_C_D_E = { C, D, E };
-  //   switch (Dir) {
-  //     case A_B:
-  //       return leafConstructs_A_B;
-  //     case C_D_E:
-  //       return leafConstructs_C_D_E;
-  //   }
-
-  // Map from a record that defines a directive to the name of the
-  // local object with the list of its leaves.
-  DenseMap<Record *, std::string> ListNames;
-
-  std::string DirectiveTypeName =
-      std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
-
-  OS << '\n';
-
-  // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir)
-  OS << "llvm::ArrayRef<" << DirectiveTypeName
-     << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
-     << DirectiveTypeName << " Dir) ";
-  OS << "{\n";
-
-  // Generate the locals.
-  for (Record *R : DirLang.getDirectives()) {
-    Directive Dir{R};
+  // Each row consists of
+  // - the id of the directive itself,
+  // - number of leaf constructs that will follow (0 for leafs),
+  // - ids of the leaf constructs (none if the directive is itself a leaf).
+  // The total number of these entries is at most MaxLeafCount+2. If this
+  // number is less than that, it is padded to occupy exactly MaxLeafCount+2
+  // entries in memory.
+  //
+  // The rows are stored in the table in the lexicographical order. This
+  // is intended to enable binary search when mapping a sequence of leafs
+  // back to the compound directive.
+  // The consequence of that is that in order to find a row corresponding
+  // to the given directive, we'd need to scan the first element of each
+  // row. To avoid this, an auxiliary ordering table is created, such that
+  //   row for Dir_A = table[auxiliary[Dir_A]].
+
+  std::vector<Record *> Directives = DirLang.getDirectives();
+  DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive
+
+  for (auto [Idx, Rec] : llvm::enumerate(Directives))
+    DirId.insert(std::make_pair(Rec, Idx));
+
+  using LeafList = std::vector<int>;
+  int MaxLeafCount = GetMaxLeafCount(DirLang);
+
+  // The initial leaf table, rows order is same as directive order.
+  std::vector<LeafList> LeafTable(Directives.size());
+  for (auto [Idx, Rec] : llvm::enumerate(Directives)) {
+    Directive Dir{Rec};
+    std::vector<Record *> Leaves = Dir.getLeafConstructs();
+
+    auto &List = LeafTable[Idx];
+    List.resize(MaxLeafCount + 2);
+    List[0] = Idx;           // The id of the directive itself.
+    List[1] = Leaves.size(); // The number of leaves to follow.
+
+    for (int I = 0; I != MaxLeafCount; ++I)
+      List[I + 2] =
+          static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1;
+  }
 
-    std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
-    if (LeafConstructs.empty())
-      continue;
+  // Avoid sorting the vector<vector> array, instead sort an index array.
+  // It will also be useful later to create the auxiliary indexing array.
+  std::vector<int> Ordering(Directives.size());
+  std::iota(Ordering.begin(), Ordering.end(), 0);
+
+  llvm::sort(Ordering, [&](int A, int B) {
+    auto &LeavesA = LeafTable[A];
+    auto &LeavesB = LeafTable[B];
+    if (LeavesA[1] == 0 && LeavesB[1] == 0)
+      return LeavesA[0] < LeavesB[0];
+    return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1],
+                                        &LeavesB[2], &LeavesB[2] + LeavesB[1]);
+  });
 
-    std::string ListName = "leafConstructs_" + Dir.getFormattedName();
-    OS << "  static const " << DirectiveTypeName << ' ' << ListName
-       << "[] = {\n";
-    for (Record *L : LeafConstructs) {
-      Directive LeafDir{L};
-      OS << "    " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
+  // Emit the table
+
+  // The directives are emitted into a scoped enum, for which the underlying
+  // type is `int` (by default). The code above uses `int` to store directive
+  // ids, so make sure that we catch it when something changes in the
+  // underlying type.
+  std::string DirectiveType = GetDirectiveType(DirLang);
+  OS << "static_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n";
+
+  OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName
+     << "[][" << MaxLeafCount + 2 << "] = {\n";
+  for (size_t I = 0, E = Directives.size(); I != E; ++I) {
+    auto &Leaves = LeafTable[Ordering[I]];
+    OS << "    " << GetDirectiveName(DirLang, Directives[Leaves[0]]);
+    OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),";
+    for (size_t I = 2, E = Leaves.size(); I != E; ++I) {
+      int Idx = Leaves[I];
+      if (Idx >= 0)
+        OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ',';
+      else
+        OS << " static_cast<" << DirectiveType << ">(-1),";
     }
-    OS << "  };\n";
-    ListNames.insert(std::make_pair(R, std::move(ListName)));
-  }
-
-  if (!ListNames.empty())
     OS << '\n';
-  OS << "  switch (Dir) {\n";
-  for (Record *R : DirLang.getDirectives()) {
-    auto F = ListNames.find(R);
-    if (F == ListNames.end())
-      continue;
-
-    Directive Dir{R};
-    OS << "  case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
-    OS << "    return " << F->second << ";\n";
   }
-  OS << "  default:\n";
-  OS << "    return ArrayRef<" << DirectiveTypeName << ">{};\n";
-  OS << "  } // switch (Dir)\n";
-  OS << "}\n";
+  OS << "};\n\n";
+
+  // Emit the auxiliary index table: it's the inverse of the `Ordering`
+  // table above.
+  OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n";
+  OS << "   ";
+  std::vector<int> Reverse(Ordering.size());
+  for (int I = 0, E = Ordering.size(); I != E; ++I)
+    Reverse[Ordering[I]] = I;
+  for (int Idx : Reverse)
+    OS << ' ' << Idx << ',';
+  OS << "\n};\n";
 }
 
 static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang,
@@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
   // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
   GenerateIsAllowedClause(DirLang, OS);
 
-  // getLeafConstructs(Directive D)
-  GenerateGetLeafConstructs(DirLang, OS);
-
   // getDirectiveAssociation(Directive D)
   GenerateGetDirectiveAssociation(DirLang, OS);
+
+  // Leaf table for getLeafConstructs, etc.
+  EmitLeafTable(DirLang, OS, "LeafConstructTable");
 }
 
 // Generate the implemenation section for the enumeration in the directive

``````````

</details>


https://github.com/llvm/llvm-project/pull/87247


More information about the llvm-branch-commits mailing list