[llvm-branch-commits] [llvm] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct (PR #87247)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Apr 2 06:21:41 PDT 2024


https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/87247

>From 291dc48d5e0b7e0ee39681a1276bd1d63f456b01 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 1 Apr 2024 10:07:45 -0500
Subject: [PATCH 1/2] [Frontend][OpenMP] Refactor getLeafConstructs, add
 getCompoundConstruct

Emit a special leaf construct table in DirectiveEmitter.cpp, which will
allow both decomposition of a construct into leafs, and composition of
constituent constructs into a single compound construct (if possible).
---
 llvm/include/llvm/Frontend/OpenMP/OMP.h       |   7 +
 llvm/lib/Frontend/OpenMP/OMP.cpp              |  64 +++++-
 llvm/test/TableGen/directive1.td              |  19 +-
 llvm/test/TableGen/directive2.td              |  19 +-
 llvm/unittests/Frontend/CMakeLists.txt        |   1 +
 llvm/unittests/Frontend/OpenMPComposeTest.cpp |  41 ++++
 llvm/utils/TableGen/DirectiveEmitter.cpp      | 194 +++++++++++-------
 7 files changed, 258 insertions(+), 87 deletions(-)
 create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index a85cd9d344c6d7..4ed47f15dfe59e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -15,4 +15,11 @@
 
 #include "llvm/Frontend/OpenMP/OMP.h.inc"
 
+#include "llvm/ADT/ArrayRef.h"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D);
+Directive getCompoundConstruct(ArrayRef<Directive> Parts);
+} // namespace llvm::omp
+
 #endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4f2f95392648b3..dd99d3d074fd1e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -8,12 +8,74 @@
 
 #include "llvm/Frontend/OpenMP/OMP.h"
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/ErrorHandling.h"
 
+#include <algorithm>
+#include <iterator>
+#include <type_traits>
+
 using namespace llvm;
-using namespace omp;
+using namespace llvm::omp;
 
 #define GEN_DIRECTIVES_IMPL
 #include "llvm/Frontend/OpenMP/OMP.inc"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D) {
+  auto Idx = static_cast<int>(D);
+  if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+    return {};
+  const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
+  return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
+}
+
+Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
+  if (Parts.empty())
+    return OMPD_unknown;
+
+  // Parts don't have to be leafs, so expand them into leafs first.
+  // Store the expanded leafs in the same format as rows in the leaf
+  // table (generated by tablegen).
+  SmallVector<Directive> RawLeafs(2);
+  for (Directive P : Parts) {
+    ArrayRef<Directive> Ls = getLeafConstructs(P);
+    if (!Ls.empty())
+      RawLeafs.append(Ls.begin(), Ls.end());
+    else
+      RawLeafs.push_back(P);
+  }
+
+  auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
+  if (GivenLeafs.size() == 1)
+    return GivenLeafs.front();
+  RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
+
+  auto Iter = llvm::lower_bound(
+      LeafConstructTable,
+      static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
+      [](const auto *RowA, const auto *RowB) {
+        const auto *BeginA = &RowA[2];
+        const auto *EndA = BeginA + static_cast<int>(RowA[1]);
+        const auto *BeginB = &RowB[2];
+        const auto *EndB = BeginB + static_cast<int>(RowB[1]);
+        if (BeginA == EndA && BeginB == EndB)
+          return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
+        return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
+      });
+
+  if (Iter == std::end(LeafConstructTable))
+    return OMPD_unknown;
+
+  // Verify that we got a match.
+  Directive Found = (*Iter)[0];
+  ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
+  if (FoundLeafs == GivenLeafs)
+    return Found;
+  return OMPD_unknown;
+}
+} // namespace llvm::omp
diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td
index 3184f625ead928..e6150210e7e9a4 100644
--- a/llvm/test/TableGen/directive1.td
+++ b/llvm/test/TableGen/directive1.td
@@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> {
 // CHECK-EMPTY:
 // CHECK-NEXT:  #include "llvm/ADT/ArrayRef.h"
 // CHECK-NEXT:  #include "llvm/ADT/BitmaskEnum.h"
+// CHECK-NEXT:  #include <cstddef>
 // CHECK-EMPTY:
 // CHECK-NEXT:  namespace llvm {
 // CHECK-NEXT:  class StringRef;
@@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> {
 // CHECK-NEXT:  /// Return true if \p C is a valid clause for \p D in version \p Version.
 // CHECK-NEXT:  bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
 // CHECK-EMPTY:
-// CHECK-NEXT:  llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
+// CHECK-NEXT:  constexpr std::size_t getMaxLeafCount() { return 0; }
 // CHECK-NEXT:  Association getDirectiveAssociation(Directive D);
 // CHECK-NEXT:  AKind getAKind(StringRef);
 // CHECK-NEXT:  llvm::StringRef getTdlAKindName(AKind);
@@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> {
 // IMPL-NEXT:    llvm_unreachable("Invalid Tdl Directive kind");
 // IMPL-NEXT:  }
 // IMPL-EMPTY:
-// IMPL-NEXT:  llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
-// IMPL-NEXT:    switch (Dir) {
-// IMPL-NEXT:    default:
-// IMPL-NEXT:      return ArrayRef<llvm::tdl::Directive>{};
-// IMPL-NEXT:    } // switch (Dir)
-// IMPL-NEXT:  }
-// IMPL-EMPTY:
 // IMPL-NEXT:  llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
 // IMPL-NEXT:    switch (Dir) {
 // IMPL-NEXT:    case llvm::tdl::Directive::TDLD_dira:
@@ -374,4 +368,13 @@ def TDL_DirA : Directive<"dira"> {
 // IMPL-NEXT:    llvm_unreachable("Unexpected directive");
 // IMPL-NEXT:  }
 // IMPL-EMPTY:
+// IMPL-NEXT:  static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
+// IMPL-NEXT:  {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
+// IMPL-NEXT:    llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
+// IMPL-NEXT:  };
+// IMPL-EMPTY:
+// IMPL-NEXT:  {{.*}} static const int LeafConstructTableOrdering[] = {
+// IMPL-NEXT:    0,
+// IMPL-NEXT:  };
+// IMPL-EMPTY:
 // IMPL-NEXT:  #endif // GEN_DIRECTIVES_IMPL
diff --git a/llvm/test/TableGen/directive2.td b/llvm/test/TableGen/directive2.td
index d6fa4835c8dfdc..1750022e1f94ea 100644
--- a/llvm/test/TableGen/directive2.td
+++ b/llvm/test/TableGen/directive2.td
@@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> {
 // CHECK-NEXT:  #define LLVM_Tdl_INC
 // CHECK-EMPTY:
 // CHECK-NEXT:  #include "llvm/ADT/ArrayRef.h"
+// CHECK-NEXT:  #include <cstddef>
 // CHECK-EMPTY:
 // CHECK-NEXT:  namespace llvm {
 // CHECK-NEXT:  class StringRef;
@@ -88,7 +89,7 @@ def TDL_DirA : Directive<"dira"> {
 // CHECK-NEXT:  /// Return true if \p C is a valid clause for \p D in version \p Version.
 // CHECK-NEXT:  bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
 // CHECK-EMPTY:
-// CHECK-NEXT:  llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
+// CHECK-NEXT:  constexpr std::size_t getMaxLeafCount() { return 0; }
 // CHECK-NEXT:  Association getDirectiveAssociation(Directive D);
 // CHECK-NEXT:  } // namespace tdl
 // CHECK-NEXT:  } // namespace llvm
@@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> {
 // IMPL-NEXT:    llvm_unreachable("Invalid Tdl Directive kind");
 // IMPL-NEXT:  }
 // IMPL-EMPTY:
-// IMPL-NEXT:  llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
-// IMPL-NEXT:    switch (Dir) {
-// IMPL-NEXT:    default:
-// IMPL-NEXT:      return ArrayRef<llvm::tdl::Directive>{};
-// IMPL-NEXT:    } // switch (Dir)
-// IMPL-NEXT:  }
-// IMPL-EMPTY:
 // IMPL-NEXT:  llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
 // IMPL-NEXT:    switch (Dir) {
 // IMPL-NEXT:    case llvm::tdl::Directive::TDLD_dira:
@@ -305,4 +299,13 @@ def TDL_DirA : Directive<"dira"> {
 // IMPL-NEXT:    llvm_unreachable("Unexpected directive");
 // IMPL-NEXT:  }
 // IMPL-EMPTY:
+// IMPL-NEXT:  static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
+// IMPL-NEXT:  {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
+// IMPL-NEXT:    llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
+// IMPL-NEXT:  };
+// IMPL-EMPTY:
+// IMPL-NEXT:  {{.*}} static const int LeafConstructTableOrdering[] = {
+// IMPL-NEXT:    0,
+// IMPL-NEXT:  };
+// IMPL-EMPTY:
 // IMPL-NEXT:  #endif // GEN_DIRECTIVES_IMPL
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index c6f60142d6276a..ddb6a16cbb984e 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
   OpenMPContextTest.cpp
   OpenMPIRBuilderTest.cpp
   OpenMPParsingTest.cpp
+  OpenMPComposeTest.cpp
 
   DEPENDS
   acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
new file mode 100644
index 00000000000000..29b1be4eb3432c
--- /dev/null
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -0,0 +1,41 @@
+//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::omp;
+
+TEST(Composition, GetLeafConstructs) {
+  ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
+  ASSERT_EQ(L1, (ArrayRef<Directive>{}));
+  ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
+  ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
+  ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
+  ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
+}
+
+TEST(Composition, GetCompoundConstruct) {
+  Directive C1 =
+      getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
+  ASSERT_EQ(C1, OMPD_target_teams_distribute);
+  Directive C2 = getCompoundConstruct({OMPD_target});
+  ASSERT_EQ(C2, OMPD_target);
+  Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
+  ASSERT_EQ(C3, OMPD_unknown);
+  Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+  ASSERT_EQ(C4, OMPD_target_teams_distribute);
+  Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+  ASSERT_EQ(C5, OMPD_target_teams_distribute);
+  Directive C6 = getCompoundConstruct({});
+  ASSERT_EQ(C6, OMPD_unknown);
+  Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+  ASSERT_EQ(C7, OMPD_parallel_for_simd);
+}
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index e0edf1720f8ac5..2d2b7748491897 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -20,6 +20,9 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#include <numeric>
+#include <vector>
+
 using namespace llvm;
 
 namespace {
@@ -39,7 +42,8 @@ class IfDefScope {
 };
 } // namespace
 
-// Generate enum class
+// Generate enum class. Entries are emitted in the order in which they appear
+// in the `Records` vector.
 static void GenerateEnumClass(const std::vector<Record *> &Records,
                               raw_ostream &OS, StringRef Enum, StringRef Prefix,
                               const DirectiveLanguage &DirLang,
@@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const {
   return HasDuplicateClausesInDirectives(getDirectives());
 }
 
+// Count the maximum number of leaf constituents per construct.
+static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) {
+  size_t MaxCount = 0;
+  for (Record *R : DirLang.getDirectives()) {
+    size_t Count = Directive{R}.getLeafConstructs().size();
+    MaxCount = std::max(MaxCount, Count);
+  }
+  return MaxCount;
+}
+
 // Generate the declaration section for the enumeration in the directive
 // language
 static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   if (DirLang.hasEnableBitmaskEnumInNamespace())
     OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n";
 
+  OS << "#include <cstddef>\n"; // for size_t
   OS << "\n";
   OS << "namespace llvm {\n";
   OS << "class StringRef;\n";
@@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   OS << "bool isAllowedClauseForDirective(Directive D, "
      << "Clause C, unsigned Version);\n";
   OS << "\n";
-  OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n";
+  OS << "constexpr std::size_t getMaxLeafCount() { return "
+     << GetMaxLeafCount(DirLang) << "; }\n";
   OS << "Association getDirectiveAssociation(Directive D);\n";
   if (EnumHelperFuncs.length() > 0) {
     OS << EnumHelperFuncs;
@@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
   }
 }
 
+static std::string GetDirectiveName(const DirectiveLanguage &DirLang,
+                                    const Record *Rec) {
+  Directive Dir{Rec};
+  return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" +
+          DirLang.getDirectivePrefix() + Dir.getFormattedName())
+      .str();
+}
+
+static std::string GetDirectiveType(const DirectiveLanguage &DirLang) {
+  return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive")
+      .str();
+}
+
 // Generate the isAllowedClauseForDirective function implementation.
 static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
                                     raw_ostream &OS) {
@@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
   OS << "}\n"; // End of function isAllowedClauseForDirective
 }
 
-// Generate the getLeafConstructs function implementation.
-static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
-                                      raw_ostream &OS) {
-  auto getQualifiedName = [&](StringRef Formatted) -> std::string {
-    return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
-            "::Directive::" + DirLang.getDirectivePrefix() + Formatted)
-        .str();
-  };
-
-  // For each list of leaves, generate a static local object, then
-  // return a reference to that object for a given directive, e.g.
+static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS,
+                          StringRef TableName) {
+  // The leaf constructs are emitted in the form of a 2D table, where each
+  // row corresponds to a directive (and there is a row for each directive).
   //
-  //   static ListTy leafConstructs_A_B = { A, B };
-  //   static ListTy leafConstructs_C_D_E = { C, D, E };
-  //   switch (Dir) {
-  //     case A_B:
-  //       return leafConstructs_A_B;
-  //     case C_D_E:
-  //       return leafConstructs_C_D_E;
-  //   }
-
-  // Map from a record that defines a directive to the name of the
-  // local object with the list of its leaves.
-  DenseMap<Record *, std::string> ListNames;
-
-  std::string DirectiveTypeName =
-      std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
-
-  OS << '\n';
-
-  // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir)
-  OS << "llvm::ArrayRef<" << DirectiveTypeName
-     << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
-     << DirectiveTypeName << " Dir) ";
-  OS << "{\n";
-
-  // Generate the locals.
-  for (Record *R : DirLang.getDirectives()) {
-    Directive Dir{R};
+  // Each row consists of
+  // - the id of the directive itself,
+  // - number of leaf constructs that will follow (0 for leafs),
+  // - ids of the leaf constructs (none if the directive is itself a leaf).
+  // The total number of these entries is at most MaxLeafCount+2. If this
+  // number is less than that, it is padded to occupy exactly MaxLeafCount+2
+  // entries in memory.
+  //
+  // The rows are stored in the table in lexicographical order. This
+  // is intended to enable binary search when mapping a sequence of leafs
+  // back to the compound directive.
+  // The consequence of that is that in order to find a row corresponding
+  // to the given directive, we'd need to scan the first element of each
+  // row. To avoid this, an auxiliary ordering table is created, such that
+  //   row for Dir_A = table[auxiliary[Dir_A]].
+
+  std::vector<Record *> Directives = DirLang.getDirectives();
+  DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive
+
+  for (auto [Idx, Rec] : llvm::enumerate(Directives))
+    DirId.insert(std::make_pair(Rec, Idx));
+
+  using LeafList = std::vector<int>;
+  int MaxLeafCount = GetMaxLeafCount(DirLang);
+
+  // The initial leaf table; row order is the same as directive order.
+  std::vector<LeafList> LeafTable(Directives.size());
+  for (auto [Idx, Rec] : llvm::enumerate(Directives)) {
+    Directive Dir{Rec};
+    std::vector<Record *> Leaves = Dir.getLeafConstructs();
+
+    auto &List = LeafTable[Idx];
+    List.resize(MaxLeafCount + 2);
+    List[0] = Idx;           // The id of the directive itself.
+    List[1] = Leaves.size(); // The number of leaves to follow.
+
+    for (int I = 0; I != MaxLeafCount; ++I)
+      List[I + 2] =
+          static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1;
+  }
 
-    std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
-    if (LeafConstructs.empty())
-      continue;
+  // Avoid sorting the vector<vector> array, instead sort an index array.
+  // It will also be useful later to create the auxiliary indexing array.
+  std::vector<int> Ordering(Directives.size());
+  std::iota(Ordering.begin(), Ordering.end(), 0);
+
+  llvm::sort(Ordering, [&](int A, int B) {
+    auto &LeavesA = LeafTable[A];
+    auto &LeavesB = LeafTable[B];
+    if (LeavesA[1] == 0 && LeavesB[1] == 0)
+      return LeavesA[0] < LeavesB[0];
+    return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1],
+                                        &LeavesB[2], &LeavesB[2] + LeavesB[1]);
+  });
 
-    std::string ListName = "leafConstructs_" + Dir.getFormattedName();
-    OS << "  static const " << DirectiveTypeName << ' ' << ListName
-       << "[] = {\n";
-    for (Record *L : LeafConstructs) {
-      Directive LeafDir{L};
-      OS << "    " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
+  // Emit the table
+
+  // The directives are emitted into a scoped enum, for which the underlying
+  // type is `int` (by default). The code above uses `int` to store directive
+  // ids, so make sure that we catch it when something changes in the
+  // underlying type.
+  std::string DirectiveType = GetDirectiveType(DirLang);
+  OS << "\nstatic_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n";
+
+  OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName
+     << "[][" << MaxLeafCount + 2 << "] = {\n";
+  for (size_t I = 0, E = Directives.size(); I != E; ++I) {
+    auto &Leaves = LeafTable[Ordering[I]];
+    OS << "    " << GetDirectiveName(DirLang, Directives[Leaves[0]]);
+    OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),";
+    for (size_t I = 2, E = Leaves.size(); I != E; ++I) {
+      int Idx = Leaves[I];
+      if (Idx >= 0)
+        OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ',';
+      else
+        OS << " static_cast<" << DirectiveType << ">(-1),";
     }
-    OS << "  };\n";
-    ListNames.insert(std::make_pair(R, std::move(ListName)));
-  }
-
-  if (!ListNames.empty())
     OS << '\n';
-  OS << "  switch (Dir) {\n";
-  for (Record *R : DirLang.getDirectives()) {
-    auto F = ListNames.find(R);
-    if (F == ListNames.end())
-      continue;
-
-    Directive Dir{R};
-    OS << "  case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
-    OS << "    return " << F->second << ";\n";
   }
-  OS << "  default:\n";
-  OS << "    return ArrayRef<" << DirectiveTypeName << ">{};\n";
-  OS << "  } // switch (Dir)\n";
-  OS << "}\n";
+  OS << "};\n\n";
+
+  // Emit the auxiliary index table: it's the inverse of the `Ordering`
+  // table above.
+  OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n";
+  OS << "   ";
+  std::vector<int> Reverse(Ordering.size());
+  for (int I = 0, E = Ordering.size(); I != E; ++I)
+    Reverse[Ordering[I]] = I;
+  for (int Idx : Reverse)
+    OS << ' ' << Idx << ',';
+  OS << "\n};\n";
 }
 
 static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang,
@@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
   // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
   GenerateIsAllowedClause(DirLang, OS);
 
-  // getLeafConstructs(Directive D)
-  GenerateGetLeafConstructs(DirLang, OS);
-
   // getDirectiveAssociation(Directive D)
   GenerateGetDirectiveAssociation(DirLang, OS);
+
+  // Leaf table for getLeafConstructs, etc.
+  EmitLeafTable(DirLang, OS, "LeafConstructTable");
 }
 
 // Generate the implemenation section for the enumeration in the directive

>From 0d92781c7a52ed2fbab33ae6e7b3dae61cfd42ae Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 2 Apr 2024 08:20:15 -0500
Subject: [PATCH 2/2] Address review comments

---
 llvm/lib/Frontend/OpenMP/OMP.cpp              | 10 ++++++++--
 llvm/unittests/Frontend/OpenMPComposeTest.cpp | 10 ++++------
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index dd99d3d074fd1e..7504c9076fde1b 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -27,8 +27,8 @@ using namespace llvm::omp;
 
 namespace llvm::omp {
 ArrayRef<Directive> getLeafConstructs(Directive D) {
-  auto Idx = static_cast<int>(D);
-  if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+  auto Idx = static_cast<std::size_t>(D);
+  if (Idx >= Directive_enumSize)
     return {};
   const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
   return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
@@ -50,6 +50,12 @@ Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
       RawLeafs.push_back(P);
   }
 
+  // RawLeafs will be used as key in the binary search. The search doesn't
+  // guarantee that the exact same entry will be found (since RawLeafs may
+  // not correspond to any compound directive). Because of that, we will
+  // need to compare the search result with the given set of leafs.
+  // Also, if there is only one leaf in the list, it corresponds to itself,
+  // so no search is necessary.
   auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
   if (GivenLeafs.size() == 1)
     return GivenLeafs.front();
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
index 29b1be4eb3432c..c3e0880ece8641 100644
--- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -32,10 +32,8 @@ TEST(Composition, GetCompoundConstruct) {
   ASSERT_EQ(C3, OMPD_unknown);
   Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
   ASSERT_EQ(C4, OMPD_target_teams_distribute);
-  Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
-  ASSERT_EQ(C5, OMPD_target_teams_distribute);
-  Directive C6 = getCompoundConstruct({});
-  ASSERT_EQ(C6, OMPD_unknown);
-  Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
-  ASSERT_EQ(C7, OMPD_parallel_for_simd);
+  Directive C5 = getCompoundConstruct({});
+  ASSERT_EQ(C5, OMPD_unknown);
+  Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+  ASSERT_EQ(C6, OMPD_parallel_for_simd);
 }



More information about the llvm-branch-commits mailing list