[llvm-branch-commits] [llvm] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct (PR #87247)
Krzysztof Parzyszek via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Apr 1 08:22:41 PDT 2024
https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/87247
Emit a special leaf construct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leaves, and composition of constituent constructs into a single compound construct (if possible).
>From 2fec99813013adf1ab6b262132ddebe4356ce643 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 1 Apr 2024 10:07:45 -0500
Subject: [PATCH] [Frontend][OpenMP] Refactor getLeafConstructs, add
getCompoundConstruct
Emit a special leaf construct table in DirectiveEmitter.cpp, which will
allow both decomposition of a construct into leaves, and composition of
constituent constructs into a single compound construct (if possible).
---
llvm/include/llvm/Frontend/OpenMP/OMP.h | 7 +
llvm/lib/Frontend/OpenMP/OMP.cpp | 64 +++++-
llvm/unittests/Frontend/CMakeLists.txt | 1 +
llvm/unittests/Frontend/OpenMPComposeTest.cpp | 40 ++++
llvm/utils/TableGen/DirectiveEmitter.cpp | 194 +++++++++++-------
5 files changed, 235 insertions(+), 71 deletions(-)
create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index a85cd9d344c6d7..4ed47f15dfe59e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -15,4 +15,11 @@
#include "llvm/Frontend/OpenMP/OMP.h.inc"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace llvm::omp {
+/// Return the leaf constructs that make up the directive \p D, in order.
+/// The result is empty if \p D is itself a leaf, or if \p D is not a value
+/// of the Directive enumeration.
+ArrayRef<Directive> getLeafConstructs(Directive D);
+/// Return the directive whose leaf constructs are exactly the leaves of the
+/// directives in \p Parts, concatenated in order. A part may be a leaf or a
+/// compound directive; compound parts are expanded into their leaves first.
+/// If the expansion yields a single directive, that directive is returned.
+/// Returns OMPD_unknown if \p Parts is empty or no such directive exists.
+Directive getCompoundConstruct(ArrayRef<Directive> Parts);
+} // namespace llvm::omp
+
#endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4f2f95392648b3..dd99d3d074fd1e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -8,12 +8,74 @@
#include "llvm/Frontend/OpenMP/OMP.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
+#include <algorithm>
+#include <iterator>
+#include <type_traits>
+
using namespace llvm;
-using namespace omp;
+using namespace llvm::omp;
#define GEN_DIRECTIVES_IMPL
#include "llvm/Frontend/OpenMP/OMP.inc"
+
+namespace llvm::omp {
+/// Return the leaf constructs of the directive \p D, in order. The result is
+/// empty if \p D is a leaf itself or if \p D is outside the enumeration.
+///
+/// The lookup goes through the tablegen-generated LeafConstructTable, whose
+/// rows have the layout [directive id, leaf count, leaf ids..., padding].
+/// The rows are not sorted by directive id, so LeafConstructTableOrdering is
+/// used to map the id of \p D to the index of its row.
+ArrayRef<Directive> getLeafConstructs(Directive D) {
+ auto Idx = static_cast<int>(D);
+ // Ids outside the enumeration have no row in the table.
+ if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+ return {};
+ const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
+ // Row[1] holds the leaf count (stored as a Directive value); the leaf ids
+ // start at Row[2]. Padding past the count is not part of the result.
+ return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
+}
+
+/// Return the directive whose leaf constructs are exactly the leaves of
+/// \p Parts, concatenated in order. Each element of \p Parts may be a leaf or
+/// a compound directive; compound ones are expanded into their leaves first.
+/// If the expansion yields a single directive, that directive is returned.
+/// Returns OMPD_unknown if \p Parts is empty or no such directive exists.
+Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
+ if (Parts.empty())
+ return OMPD_unknown;
+
+ // Parts don't have to be leaves, so expand them into leaves first.
+ // Store the expanded leaves in the same format as rows in the leaf
+ // table (generated by tablegen): two header slots (directive id and leaf
+ // count) followed by the leaf ids, so that RawLeafs.data() can serve as
+ // the search key below.
+ SmallVector<Directive> RawLeafs(2);
+ for (Directive P : Parts) {
+ ArrayRef<Directive> Ls = getLeafConstructs(P);
+ if (!Ls.empty())
+ RawLeafs.append(Ls.begin(), Ls.end());
+ else
+ RawLeafs.push_back(P); // P is a leaf (or has no row in the table).
+ }
+
+ auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
+ // A single leaf composes into itself; no table lookup is needed.
+ if (GivenLeafs.size() == 1)
+ return GivenLeafs.front();
+ // Fill in the leaf-count slot of the key row. The id slot (RawLeafs[0]) is
+ // only consulted when both compared rows have no leaves, which cannot
+ // happen for this key (it has at least two leaves at this point).
+ RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
+
+ // Binary-search the table. The comparator has to agree with the order in
+ // which the generator sorted the rows: rows without leaves are ordered by
+ // directive id among themselves, and otherwise rows are ordered
+ // lexicographically by their leaf sequences.
+ auto Iter = llvm::lower_bound(
+ LeafConstructTable,
+ static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
+ [](const auto *RowA, const auto *RowB) {
+ const auto *BeginA = &RowA[2];
+ const auto *EndA = BeginA + static_cast<int>(RowA[1]);
+ const auto *BeginB = &RowB[2];
+ const auto *EndB = BeginB + static_cast<int>(RowB[1]);
+ if (BeginA == EndA && BeginB == EndB)
+ return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
+ return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
+ });
+
+ if (Iter == std::end(LeafConstructTable))
+ return OMPD_unknown;
+
+ // Verify that we got a match: lower_bound only yields the first row that
+ // is not less than the key, and that row may have a different leaf sequence.
+ Directive Found = (*Iter)[0];
+ ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
+ if (FoundLeafs == GivenLeafs)
+ return Found;
+ return OMPD_unknown;
+}
+} // namespace llvm::omp
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index c6f60142d6276a..ddb6a16cbb984e 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
OpenMPContextTest.cpp
OpenMPIRBuilderTest.cpp
OpenMPParsingTest.cpp
+ OpenMPComposeTest.cpp
DEPENDS
acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
new file mode 100644
index 00000000000000..2dc35aca8842e9
--- /dev/null
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -0,0 +1,40 @@
+//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::omp;
+
+TEST(Composition, GetLeafConstructs) {
+ // A leaf directive has no leaf constituents of its own.
+ ASSERT_EQ(getLeafConstructs(OMPD_loop), (ArrayRef<Directive>{}));
+
+ // Compound directives decompose into their leaves, in source order.
+ ArrayRef<Directive> ParallelFor = getLeafConstructs(OMPD_parallel_for);
+ ASSERT_EQ(ParallelFor, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
+
+ ArrayRef<Directive> ParallelForSimd =
+ getLeafConstructs(OMPD_parallel_for_simd);
+ ASSERT_EQ(ParallelForSimd,
+ (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
+}
+
+TEST(Composition, GetCompoundConstruct) {
+ // A sequence of leaves composes into the matching compound directive.
+ Directive C1 = getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
+ ASSERT_EQ(C1, OMPD_target_teams_distribute);
+ // A single leaf composes into itself.
+ Directive C2 = getCompoundConstruct({OMPD_target});
+ ASSERT_EQ(C2, OMPD_target);
+ // A leaf sequence that no directive has yields OMPD_unknown.
+ Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
+ ASSERT_EQ(C3, OMPD_unknown);
+ // Compound parts are expanded into leaves before composing, whether the
+ // compound part comes last...
+ Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+ ASSERT_EQ(C4, OMPD_target_teams_distribute);
+ // ...or first. (This case previously duplicated C4 verbatim.)
+ Directive C5 = getCompoundConstruct({OMPD_target_teams, OMPD_distribute});
+ ASSERT_EQ(C5, OMPD_target_teams_distribute);
+ // An empty list composes into nothing.
+ Directive C6 = getCompoundConstruct({});
+ ASSERT_EQ(C6, OMPD_unknown);
+ Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+ ASSERT_EQ(C7, OMPD_parallel_for_simd);
+}
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index e0edf1720f8ac5..34b517e816a243 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -20,6 +20,9 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
+#include <numeric>
+#include <vector>
+
using namespace llvm;
namespace {
@@ -39,7 +42,8 @@ class IfDefScope {
};
} // namespace
-// Generate enum class
+// Generate enum class. Entries are emitted in the order in which they appear
+// in the `Records` vector.
static void GenerateEnumClass(const std::vector<Record *> &Records,
raw_ostream &OS, StringRef Enum, StringRef Prefix,
const DirectiveLanguage &DirLang,
@@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const {
return HasDuplicateClausesInDirectives(getDirectives());
}
+// Return the largest number of leaf constituents of any directive defined
+// for DirLang. Leaf directives contribute 0, so the result starts at 0.
+static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) {
+ size_t Result = 0;
+ for (Record *Rec : DirLang.getDirectives())
+ Result = std::max(Result, Directive{Rec}.getLeafConstructs().size());
+ return Result;
+}
+
// Generate the declaration section for the enumeration in the directive
// language
static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
if (DirLang.hasEnableBitmaskEnumInNamespace())
OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n";
+ OS << "#include <cstddef>\n"; // for size_t
OS << "\n";
OS << "namespace llvm {\n";
OS << "class StringRef;\n";
@@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
OS << "bool isAllowedClauseForDirective(Directive D, "
<< "Clause C, unsigned Version);\n";
OS << "\n";
- OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n";
+ OS << "constexpr std::size_t getMaxLeafCount() { return "
+ << GetMaxLeafCount(DirLang) << "; }\n";
OS << "Association getDirectiveAssociation(Directive D);\n";
if (EnumHelperFuncs.length() > 0) {
OS << EnumHelperFuncs;
@@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
}
}
+// Return the fully qualified C++ spelling of the enumerator for the directive
+// defined by Rec, i.e. "llvm::<CppNamespace>::Directive::<Prefix><Name>".
+// The enumerator is qualified with the scoped enum type itself, so the
+// generated code does not rely on the enumerators also being made visible
+// directly in the language namespace.
+static std::string GetDirectiveName(const DirectiveLanguage &DirLang,
+ const Record *Rec) {
+ Directive Dir{Rec};
+ return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
+ "::Directive::" + DirLang.getDirectivePrefix() +
+ Dir.getFormattedName())
+ .str();
+}
+
+// Return the fully qualified name of the generated directive enum type,
+// i.e. "llvm::<CppNamespace>::Directive".
+static std::string GetDirectiveType(const DirectiveLanguage &DirLang) {
+ std::string TypeName = "llvm::";
+ TypeName += DirLang.getCppNamespace().str();
+ TypeName += "::Directive";
+ return TypeName;
+}
+
// Generate the isAllowedClauseForDirective function implementation.
static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
raw_ostream &OS) {
@@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
OS << "}\n"; // End of function isAllowedClauseForDirective
}
-// Generate the getLeafConstructs function implementation.
-static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
- raw_ostream &OS) {
- auto getQualifiedName = [&](StringRef Formatted) -> std::string {
- return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
- "::Directive::" + DirLang.getDirectivePrefix() + Formatted)
- .str();
- };
-
- // For each list of leaves, generate a static local object, then
- // return a reference to that object for a given directive, e.g.
+// Emit the leaf construct table named TableName, together with the auxiliary
+// table TableName##Ordering that maps a directive id to its row in it. The
+// layout of both tables is described below.
+static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS,
+ StringRef TableName) {
+ // The leaf constructs are emitted in the form of a 2D table, where each
+ // row corresponds to a directive (and there is a row for each directive).
//
- // static ListTy leafConstructs_A_B = { A, B };
- // static ListTy leafConstructs_C_D_E = { C, D, E };
- // switch (Dir) {
- // case A_B:
- // return leafConstructs_A_B;
- // case C_D_E:
- // return leafConstructs_C_D_E;
- // }
-
- // Map from a record that defines a directive to the name of the
- // local object with the list of its leaves.
- DenseMap<Record *, std::string> ListNames;
-
- std::string DirectiveTypeName =
- std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
-
- OS << '\n';
-
- // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir)
- OS << "llvm::ArrayRef<" << DirectiveTypeName
- << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
- << DirectiveTypeName << " Dir) ";
- OS << "{\n";
-
- // Generate the locals.
- for (Record *R : DirLang.getDirectives()) {
- Directive Dir{R};
+ // Each row consists of
+ // - the id of the directive itself,
+ // - number of leaf constructs that will follow (0 for leaf directives),
+ // - ids of the leaf constructs (none if the directive is itself a leaf).
+ // The total number of these entries is at most MaxLeafCount+2. If this
+ // number is less than that, the row is padded with -1 to occupy exactly
+ // MaxLeafCount+2 entries in memory.
+ //
+ // The rows are stored in the table in the lexicographical order. This
+ // is intended to enable binary search when mapping a sequence of leaves
+ // back to the compound directive.
+ // The consequence of that is that in order to find a row corresponding
+ // to the given directive, we'd need to scan the first element of each
+ // row. To avoid this, an auxiliary ordering table is created, such that
+ // row for Dir_A = table[auxiliary[Dir_A]].
+
+ std::vector<Record *> Directives = DirLang.getDirectives();
+ // Record * -> directive id (its position in Directives, which is also the
+ // order in which the enum entries are emitted).
+ DenseMap<Record *, size_t> DirId;
+
+ for (auto [Idx, Rec] : llvm::enumerate(Directives))
+ DirId.insert(std::make_pair(Rec, Idx));
+
+ using LeafList = std::vector<int>;
+ int MaxLeafCount = static_cast<int>(GetMaxLeafCount(DirLang));
+
+ // The initial leaf table, rows order is same as directive order.
+ std::vector<LeafList> LeafTable(Directives.size());
+ for (auto [Idx, Rec] : llvm::enumerate(Directives)) {
+ Directive Dir{Rec};
+ std::vector<Record *> Leaves = Dir.getLeafConstructs();
+
+ auto &List = LeafTable[Idx];
+ List.resize(MaxLeafCount + 2);
+ List[0] = static_cast<int>(Idx); // The id of the directive itself.
+ List[1] = static_cast<int>(Leaves.size()); // The number of leaves.
+
+ // Copy the leaf ids and pad the unused slots with -1, which is never a
+ // valid directive id (ids are non-negative positions in Directives).
+ // Both arms of the conditional are int, so no signed/unsigned mixing.
+ for (int I = 0; I != MaxLeafCount; ++I)
+ List[I + 2] = static_cast<size_t>(I) < Leaves.size()
+ ? static_cast<int>(DirId.at(Leaves[I]))
+ : -1;
+ }
- std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
- if (LeafConstructs.empty())
- continue;
+ // Avoid sorting the vector<vector> array, instead sort an index array.
+ // It will also be useful later to create the auxiliary indexing array.
+ std::vector<int> Ordering(Directives.size());
+ std::iota(Ordering.begin(), Ordering.end(), 0);
+
+ // Rows of leaf directives (leaf count 0) are ordered by directive id and
+ // precede all other rows; the remaining rows are ordered lexicographically
+ // by their leaf sequences. getCompoundConstruct's binary search over the
+ // emitted table uses a comparator that must match this order.
+ llvm::sort(Ordering, [&](int A, int B) {
+ const auto &LeavesA = LeafTable[A];
+ const auto &LeavesB = LeafTable[B];
+ if (LeavesA[1] == 0 && LeavesB[1] == 0)
+ return LeavesA[0] < LeavesB[0];
+ return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1],
+ &LeavesB[2], &LeavesB[2] + LeavesB[1]);
+ });
- std::string ListName = "leafConstructs_" + Dir.getFormattedName();
- OS << " static const " << DirectiveTypeName << ' ' << ListName
- << "[] = {\n";
- for (Record *L : LeafConstructs) {
- Directive LeafDir{L};
- OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
+ // Emit the table
+
+ // The directives are emitted into a scoped enum, for which the underlying
+ // type is `int` (by default). The code above uses `int` to store directive
+ // ids, so make sure that we catch it when something changes in the
+ // underlying type.
+ std::string DirectiveType = GetDirectiveType(DirLang);
+ OS << "static_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n";
+
+ OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName
+ << "[][" << MaxLeafCount + 2 << "] = {\n";
+ for (size_t I = 0, E = Directives.size(); I != E; ++I) {
+ const auto &Leaves = LeafTable[Ordering[I]];
+ OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]);
+ OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),";
+ // Leaf ids, then padding entries (-1) up to the fixed row width.
+ for (size_t J = 2, EJ = Leaves.size(); J != EJ; ++J) {
+ int Idx = Leaves[J];
+ if (Idx >= 0)
+ OS << ' ' << GetDirectiveName(DirLang, Directives[Idx]) << ',';
+ else
+ OS << " static_cast<" << DirectiveType << ">(-1),";
}
- OS << " };\n";
- ListNames.insert(std::make_pair(R, std::move(ListName)));
- }
-
- if (!ListNames.empty())
OS << '\n';
- OS << " switch (Dir) {\n";
- for (Record *R : DirLang.getDirectives()) {
- auto F = ListNames.find(R);
- if (F == ListNames.end())
- continue;
-
- Directive Dir{R};
- OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
- OS << " return " << F->second << ";\n";
}
- OS << " default:\n";
- OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n";
- OS << " } // switch (Dir)\n";
- OS << "}\n";
+ OS << "};\n\n";
+
+ // Emit the auxiliary index table: it's the inverse of the `Ordering`
+ // table above, i.e. Reverse[directive id] = row index in the table.
+ OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n";
+ OS << " ";
+ std::vector<int> Reverse(Ordering.size());
+ for (int I = 0, E = static_cast<int>(Ordering.size()); I != E; ++I)
+ Reverse[Ordering[I]] = I;
+ for (int Idx : Reverse)
+ OS << ' ' << Idx << ',';
+ OS << "\n};\n";
}
static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang,
@@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
// isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
GenerateIsAllowedClause(DirLang, OS);
- // getLeafConstructs(Directive D)
- GenerateGetLeafConstructs(DirLang, OS);
-
// getDirectiveAssociation(Directive D)
GenerateGetDirectiveAssociation(DirLang, OS);
+
+ // Leaf table for getLeafConstructs, etc.
+ EmitLeafTable(DirLang, OS, "LeafConstructTable");
}
// Generate the implemenation section for the enumeration in the directive
More information about the llvm-branch-commits
mailing list