[Mlir-commits] [mlir] e6c5e6e - [MLIR, OpenMP] Lowering of parallel operation: proc_bind clause 2/n

Kiran Chandramohan llvmlistbot at llvm.org
Wed Aug 12 00:04:04 PDT 2020

Author: Kiran Chandramohan
Date: 2020-08-12T08:03:13+01:00
New Revision: e6c5e6efd080ab80f133a6591a7e4f0b617c291f

URL: https://github.com/llvm/llvm-project/commit/e6c5e6efd080ab80f133a6591a7e4f0b617c291f
DIFF: https://github.com/llvm/llvm-project/commit/e6c5e6efd080ab80f133a6591a7e4f0b617c291f.diff

LOG: [MLIR,OpenMP] Lowering of parallel operation: proc_bind clause 2/n

This patch adds the translation of the proc_bind clause in a
parallel operation.

The values that can be specified for the proc_bind clause are
defined in the OMP.td tablegen file in the llvm/Frontend/OpenMP
directory. From this single source of truth, an enumeration for
proc_bind is generated in both llvm and mlir (the mlir one is used
in the specification of the parallel operation in the OpenMP
dialect). A function that returns the enum value for a given string
representation is also generated. A new header file
(DirectiveEmitter.h) containing the definitions of the classes
Directive, Clause, ClauseVal, etc. is created so that it can be
used in mlir as well.

Reviewers: clementval, jdoerfert, DavidTruby

Differential Revision: https://reviews.llvm.org/D84347




diff  --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
index 13b1edef0580..b691bf8c3a7b 100644
--- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
+++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
@@ -51,6 +51,21 @@ class DirectiveLanguage {
   string flangClauseBaseClass = "";
+// Information about values accepted by enum-like clauses
+class ClauseVal<string n, int v, bit uv> {
+  // Name of the clause value.
+  string name = n;
+  // Integer value of the clause.
+  int value = v;
+  // Can user specify this value?
+  bit isUserValue = uv;
+  // Set clause value used by default when unknown.
+  bit isDefault = 0;
 // Information about a specific clause.
 class Clause<string c> {
   // Name of the clause.
@@ -75,11 +90,17 @@ class Clause<string c> {
   // If set to 1, value is optional. Not optional by default.
   bit isValueOptional = 0;
+  // Name of enum when there is a list of allowed clause values.
+  string enumClauseValue = "";
+  // List of allowed clause values
+  list<ClauseVal> allowedClauseValues = [];
   // Is clause implicit? If clause is set as implicit, the default kind will
   // be return in get<LanguageName>ClauseKind instead of their own kind.
   bit isImplicit = 0;
-  // Set directive used by default when unknown. Function returning the kind
+  // Set clause used by default when unknown. Function returning the kind
   // of enumeration will use this clause as the default.
   bit isDefault = 0;

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 24111c9e4701..2e392156766c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -99,9 +99,22 @@ def OMPC_CopyPrivate : Clause<"copyprivate"> {
   let clangClass = "OMPCopyprivateClause";
   let flangClassValue = "OmpObjectList";
+def OMP_PROC_BIND_master : ClauseVal<"master",2,1> {}
+def OMP_PROC_BIND_close : ClauseVal<"close",3,1> {}
+def OMP_PROC_BIND_spread : ClauseVal<"spread",4,1> {}
+def OMP_PROC_BIND_default : ClauseVal<"default",5,0> {}
+def OMP_PROC_BIND_unknown : ClauseVal<"unknown",6,0> { let isDefault = 1; }
 def OMPC_ProcBind : Clause<"proc_bind"> {
   let clangClass = "OMPProcBindClause";
   let flangClass = "OmpProcBindClause";
+  let enumClauseValue = "ProcBindKind";
+  let allowedClauseValues = [
+    OMP_PROC_BIND_master,
+    OMP_PROC_BIND_close,
+    OMP_PROC_BIND_spread,
+    OMP_PROC_BIND_default,
+    OMP_PROC_BIND_unknown
+  ];
 def OMPC_Schedule : Clause<"schedule"> {
   let clangClass = "OMPScheduleClause";

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
index d171d0a2b6c4..f612fb3cd948 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -68,16 +68,6 @@ enum class DefaultKind {
   constexpr auto Enum = omp::DefaultKind::Enum;
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
-/// IDs for the different proc bind kinds.
-enum class ProcBindKind {
-#define OMP_PROC_BIND_KIND(Enum, Str, Value) Enum = Value,
-#include "llvm/Frontend/OpenMP/OMPKinds.def"
-#define OMP_PROC_BIND_KIND(Enum, ...)                                          \
-  constexpr auto Enum = omp::ProcBindKind::Enum;
-#include "llvm/Frontend/OpenMP/OMPKinds.def"
 /// IDs for all omp runtime library ident_t flag encodings (see
 /// their defintion in openmp/runtime/src/kmp.h).
 enum class IdentFlag {

diff  --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h
new file mode 100644
index 000000000000..8a7664afa98b
--- /dev/null
+++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h
@@ -0,0 +1,188 @@
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/TableGen/Record.h"
+namespace llvm {
+// Wrapper class that contains DirectiveLanguage's information defined in
+// DirectiveBase.td and provides helper methods for accessing it.
+class DirectiveLanguage {
+  explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
+  StringRef getName() const { return Def->getValueAsString("name"); }
+  StringRef getCppNamespace() const {
+    return Def->getValueAsString("cppNamespace");
+  }
+  StringRef getDirectivePrefix() const {
+    return Def->getValueAsString("directivePrefix");
+  }
+  StringRef getClausePrefix() const {
+    return Def->getValueAsString("clausePrefix");
+  }
+  StringRef getIncludeHeader() const {
+    return Def->getValueAsString("includeHeader");
+  }
+  StringRef getClauseEnumSetClass() const {
+    return Def->getValueAsString("clauseEnumSetClass");
+  }
+  StringRef getFlangClauseBaseClass() const {
+    return Def->getValueAsString("flangClauseBaseClass");
+  }
+  bool hasMakeEnumAvailableInNamespace() const {
+    return Def->getValueAsBit("makeEnumAvailableInNamespace");
+  }
+  bool hasEnableBitmaskEnumInNamespace() const {
+    return Def->getValueAsBit("enableBitmaskEnumInNamespace");
+  }
+  const llvm::Record *Def;
+// Base record class used for the Directive, Clause and ClauseVal classes
+// defined in DirectiveBase.td.
+class BaseRecord {
+  explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
+  StringRef getName() const { return Def->getValueAsString("name"); }
+  StringRef getAlternativeName() const {
+    return Def->getValueAsString("alternativeName");
+  }
+  // Returns the name of the directive formatted for output. Whitespace are
+  // replaced with underscores.
+  std::string getFormattedName() {
+    StringRef Name = Def->getValueAsString("name");
+    std::string N = Name.str();
+    std::replace(N.begin(), N.end(), ' ', '_');
+    return N;
+  }
+  bool isDefault() const { return Def->getValueAsBit("isDefault"); }
+  const llvm::Record *Def;
+// Wrapper class that contains a Directive's information defined in
+// DirectiveBase.td and provides helper methods for accessing it.
+class Directive : public BaseRecord {
+  explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
+  std::vector<Record *> getAllowedClauses() const {
+    return Def->getValueAsListOfDefs("allowedClauses");
+  }
+  std::vector<Record *> getAllowedOnceClauses() const {
+    return Def->getValueAsListOfDefs("allowedOnceClauses");
+  }
+  std::vector<Record *> getAllowedExclusiveClauses() const {
+    return Def->getValueAsListOfDefs("allowedExclusiveClauses");
+  }
+  std::vector<Record *> getRequiredClauses() const {
+    return Def->getValueAsListOfDefs("requiredClauses");
+  }
+// Wrapper class that contains Clause's information defined in DirectiveBase.td
+// and provides helper methods for accessing it.
+class Clause : public BaseRecord {
+  explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
+  // Optional field.
+  StringRef getClangClass() const {
+    return Def->getValueAsString("clangClass");
+  }
+  // Optional field.
+  StringRef getFlangClass() const {
+    return Def->getValueAsString("flangClass");
+  }
+  // Optional field.
+  StringRef getFlangClassValue() const {
+    return Def->getValueAsString("flangClassValue");
+  }
+  // Get the formatted name for Flang parser class. The generic formatted class
+  // name is constructed from the name were the first letter of each word is
+  // captitalized and the underscores are removed.
+  // ex: async -> Async
+  //     num_threads -> NumThreads
+  std::string getFormattedParserClassName() {
+    StringRef Name = Def->getValueAsString("name");
+    std::string N = Name.str();
+    bool Cap = true;
+    std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
+      if (Cap == true) {
+        C = llvm::toUpper(C);
+        Cap = false;
+      } else if (C == '_') {
+        Cap = true;
+      }
+      return C;
+    });
+    N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
+    return N;
+  }
+  // Optional field.
+  StringRef getEnumName() const {
+    return Def->getValueAsString("enumClauseValue");
+  }
+  std::vector<Record *> getClauseVals() const {
+    return Def->getValueAsListOfDefs("allowedClauseValues");
+  }
+  bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
+  bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
+// Wrapper class that contains VersionedClause's information defined in
+// DirectiveBase.td and provides helper methods for accessing it.
+class VersionedClause {
+  explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
+  // Return the specific clause record wrapped in the Clause class.
+  Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
+  int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
+  int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
+  const llvm::Record *Def;
+class ClauseVal : public BaseRecord {
+  explicit ClauseVal(const llvm::Record *Def) : BaseRecord(Def) {}
+  int getValue() const { return Def->getValueAsInt("value"); }
+  bool isUserVisible() const { return Def->getValueAsBit("isUserValue"); }
+} // namespace llvm

diff  --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td
index f647a03814ce..268999e3f0ce 100644
--- a/llvm/test/TableGen/directive1.td
+++ b/llvm/test/TableGen/directive1.td
@@ -15,9 +15,20 @@ def TestDirectiveLanguage : DirectiveLanguage {
   let flangClauseBaseClass = "TdlClause";
+def TDLCV_vala : ClauseVal<"vala",1,1> {}
+def TDLCV_valb : ClauseVal<"valb",2,1> {}
+def TDLCV_valc : ClauseVal<"valc",3,0> { let isDefault = 1; }
 def TDLC_ClauseA : Clause<"clausea"> {
   let flangClass = "TdlClauseA";
+  let enumClauseValue = "AKind";
+  let allowedClauseValues = [
+    TDLCV_vala,
+    TDLCV_valb,
+    TDLCV_valc
+  ];
 def TDLC_ClauseB : Clause<"clauseb"> {
   let flangClassValue = "IntExpr";
   let isValueOptional = 1;
@@ -61,6 +72,16 @@ def TDL_DirA : Directive<"dira"> {
 // CHECK-NEXT:  constexpr auto TDLC_clausea = llvm::tdl::Clause::TDLC_clausea;
 // CHECK-NEXT:  constexpr auto TDLC_clauseb = llvm::tdl::Clause::TDLC_clauseb;
+// CHECK-NEXT:  enum class AKind {
+// CHECK-NEXT:    TDLCV_vala=1,
+// CHECK-NEXT:    TDLCV_valb=2,
+// CHECK-NEXT:    TDLCV_valc=3,
+// CHECK-NEXT:  };
+// CHECK-NEXT:  constexpr auto TDLCV_vala = llvm::tdl::AKind::TDLCV_vala;
+// CHECK-NEXT:  constexpr auto TDLCV_valb = llvm::tdl::AKind::TDLCV_valb;
+// CHECK-NEXT:  constexpr auto TDLCV_valc = llvm::tdl::AKind::TDLCV_valc;
 // CHECK-NEXT:  // Enumeration helper functions
 // CHECK-NEXT:  Directive getTdlDirectiveKind(llvm::StringRef Str);
@@ -73,6 +94,8 @@ def TDL_DirA : Directive<"dira"> {
 // CHECK-NEXT:  /// Return true if \p C is a valid clause for \p D in version \p Version.
 // CHECK-NEXT:  bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
+// CHECK-NEXT:  AKind getAKind(StringRef);
 // CHECK-NEXT:  } // namespace tdl
 // CHECK-NEXT:  } // namespace llvm
 // CHECK-NEXT:  #endif // LLVM_Tdl_INC
@@ -116,6 +139,14 @@ def TDL_DirA : Directive<"dira"> {
 // IMPL-NEXT:    llvm_unreachable("Invalid Tdl Clause kind");
 // IMPL-NEXT:  }
+// IMPL-NEXT:  AKind llvm::tdl::getAKind(llvm::StringRef Str) {
+// IMPL-NEXT:    return llvm::StringSwitch<AKind>(Str)
+// IMPL-NEXT:      .Case("vala",TDLCV_vala)
+// IMPL-NEXT:      .Case("valb",TDLCV_valb)
+// IMPL-NEXT:      .Case("valc",TDLCV_valc)
+// IMPL-NEXT:      .Default(TDLCV_valc);
+// IMPL-NEXT:  }
 // IMPL-NEXT:  bool llvm::tdl::isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) {
 // IMPL-NEXT:    assert(unsigned(D) <= llvm::tdl::Directive_enumSize);
 // IMPL-NEXT:    assert(unsigned(C) <= llvm::tdl::Clause_enumSize);

diff  --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index 957bb0fc14b5..d89b17510113 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -11,15 +11,14 @@
+#include "llvm/TableGen/DirectiveEmitter.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 using namespace llvm;
 namespace {
@@ -41,165 +40,6 @@ class IfDefScope {
 namespace llvm {
-// Wrapper class that contains DirectiveLanguage's information defined in
-// DirectiveBase.td and provides helper methods for accessing it.
-class DirectiveLanguage {
-  explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
-  StringRef getName() const { return Def->getValueAsString("name"); }
-  StringRef getCppNamespace() const {
-    return Def->getValueAsString("cppNamespace");
-  }
-  StringRef getDirectivePrefix() const {
-    return Def->getValueAsString("directivePrefix");
-  }
-  StringRef getClausePrefix() const {
-    return Def->getValueAsString("clausePrefix");
-  }
-  StringRef getIncludeHeader() const {
-    return Def->getValueAsString("includeHeader");
-  }
-  StringRef getClauseEnumSetClass() const {
-    return Def->getValueAsString("clauseEnumSetClass");
-  }
-  StringRef getFlangClauseBaseClass() const {
-    return Def->getValueAsString("flangClauseBaseClass");
-  }
-  bool hasMakeEnumAvailableInNamespace() const {
-    return Def->getValueAsBit("makeEnumAvailableInNamespace");
-  }
-  bool hasEnableBitmaskEnumInNamespace() const {
-    return Def->getValueAsBit("enableBitmaskEnumInNamespace");
-  }
-  const llvm::Record *Def;
-// Base record class used for Directive and Clause class defined in
-// DirectiveBase.td.
-class BaseRecord {
-  explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
-  StringRef getName() const { return Def->getValueAsString("name"); }
-  StringRef getAlternativeName() const {
-    return Def->getValueAsString("alternativeName");
-  }
-  // Returns the name of the directive formatted for output. Whitespace are
-  // replaced with underscores.
-  std::string getFormattedName() {
-    StringRef Name = Def->getValueAsString("name");
-    std::string N = Name.str();
-    std::replace(N.begin(), N.end(), ' ', '_');
-    return N;
-  }
-  bool isDefault() const { return Def->getValueAsBit("isDefault"); }
-  const llvm::Record *Def;
-// Wrapper class that contains a Directive's information defined in
-// DirectiveBase.td and provides helper methods for accessing it.
-class Directive : public BaseRecord {
-  explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
-  std::vector<Record *> getAllowedClauses() const {
-    return Def->getValueAsListOfDefs("allowedClauses");
-  }
-  std::vector<Record *> getAllowedOnceClauses() const {
-    return Def->getValueAsListOfDefs("allowedOnceClauses");
-  }
-  std::vector<Record *> getAllowedExclusiveClauses() const {
-    return Def->getValueAsListOfDefs("allowedExclusiveClauses");
-  }
-  std::vector<Record *> getRequiredClauses() const {
-    return Def->getValueAsListOfDefs("requiredClauses");
-  }
-// Wrapper class that contains Clause's information defined in DirectiveBase.td
-// and provides helper methods for accessing it.
-class Clause : public BaseRecord {
-  explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
-  // Optional field.
-  StringRef getClangClass() const {
-    return Def->getValueAsString("clangClass");
-  }
-  // Optional field.
-  StringRef getFlangClass() const {
-    return Def->getValueAsString("flangClass");
-  }
-  // Optional field.
-  StringRef getFlangClassValue() const {
-    return Def->getValueAsString("flangClassValue");
-  }
-  // Get the formatted name for Flang parser class. The generic formatted class
-  // name is constructed from the name were the first letter of each word is
-  // captitalized and the underscores are removed.
-  // ex: async -> Async
-  //     num_threads -> NumThreads
-  std::string getFormattedParserClassName() {
-    StringRef Name = Def->getValueAsString("name");
-    std::string N = Name.str();
-    bool Cap = true;
-    std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
-      if (Cap == true) {
-        C = llvm::toUpper(C);
-        Cap = false;
-      } else if (C == '_') {
-        Cap = true;
-      }
-      return C;
-    });
-    N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
-    return N;
-  }
-  bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
-  bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
-// Wrapper class that contains VersionedClause's information defined in
-// DirectiveBase.td and provides helper methods for accessing it.
-class VersionedClause {
-  explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
-  // Return the specific clause record wrapped in the Clause class.
-  Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
-  int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
-  int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
-  const llvm::Record *Def;
 // Generate enum class
 void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
                        StringRef Enum, StringRef Prefix,
@@ -231,6 +71,46 @@ void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
+// Generate enums for values that clauses can take.
+// Also generate function declarations for get<EnumName>(StringRef).
+void GenerateEnumClauseVal(const std::vector<Record *> &Records,
+                           raw_ostream &OS, DirectiveLanguage &DirLang,
+                           std::string &EnumHelperFuncs) {
+  for (const auto &R : Records) {
+    Clause C{R};
+    const auto &ClauseVals = C.getClauseVals();
+    if (ClauseVals.size() <= 0)
+      continue;
+    const auto &EnumName = C.getEnumName();
+    if (EnumName.size() == 0) {
+      PrintError("enumClauseValue field not set in Clause" +
+                 C.getFormattedName() + ".");
+      return;
+    }
+    OS << "\n";
+    OS << "enum class " << EnumName << " {\n";
+    for (const auto &CV : ClauseVals) {
+      ClauseVal CVal{CV};
+      OS << "  " << CV->getName() << "=" << CVal.getValue() << ",\n";
+    }
+    OS << "};\n";
+    if (DirLang.hasMakeEnumAvailableInNamespace()) {
+      OS << "\n";
+      for (const auto &CV : ClauseVals) {
+        OS << "constexpr auto " << CV->getName() << " = "
+           << "llvm::" << DirLang.getCppNamespace() << "::" << EnumName
+           << "::" << CV->getName() << ";\n";
+      }
+      EnumHelperFuncs += (llvm::Twine(EnumName) + llvm::Twine(" get") +
+                          llvm::Twine(EnumName) + llvm::Twine("(StringRef);\n"))
+                             .str();
+    }
+  }
 // Generate the declaration section for the enumeration in the directive
 // language
 void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -273,6 +153,10 @@ void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   const auto &Clauses = Records.getAllDerivedDefinitions("Clause");
   GenerateEnumClass(Clauses, OS, "Clause", DirLang.getClausePrefix(), DirLang);
+  // Emit ClauseVal enumeration
+  std::string EnumHelperFuncs;
+  GenerateEnumClauseVal(Clauses, OS, DirLang, EnumHelperFuncs);
   // Generic function signatures
   OS << "\n";
   OS << "// Enumeration helper functions\n";
@@ -292,6 +176,10 @@ void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
   OS << "bool isAllowedClauseForDirective(Directive D, "
      << "Clause C, unsigned Version);\n";
   OS << "\n";
+  if (EnumHelperFuncs.length() > 0) {
+    OS << EnumHelperFuncs;
+    OS << "\n";
+  }
   // Closing namespaces
   for (auto Ns : llvm::reverse(Namespaces))
@@ -336,7 +224,7 @@ void GenerateGetKind(const std::vector<Record *> &Records, raw_ostream &OS,
   if (DefaultIt == Records.end()) {
-    PrintError("A least one " + Enum + " must be defined as default.");
+    PrintError("At least one " + Enum + " must be defined as default.");
@@ -361,6 +249,49 @@ void GenerateGetKind(const std::vector<Record *> &Records, raw_ostream &OS,
   OS << "}\n";
+// Generate function implementation for get<ClauseVal>Kind(StringRef Str)
+void GenerateGetKindClauseVal(const std::vector<Record *> &Records,
+                              raw_ostream &OS, StringRef Namespace) {
+  for (const auto &R : Records) {
+    Clause C{R};
+    const auto &ClauseVals = C.getClauseVals();
+    if (ClauseVals.size() <= 0)
+      continue;
+    auto DefaultIt =
+        std::find_if(ClauseVals.begin(), ClauseVals.end(), [](Record *CV) {
+          return CV->getValueAsBit("isDefault") == true;
+        });
+    if (DefaultIt == ClauseVals.end()) {
+      PrintError("At least one val in Clause " + C.getFormattedName() +
+                 " must be defined as default.");
+      return;
+    }
+    const auto DefaultName = (*DefaultIt)->getName();
+    const auto &EnumName = C.getEnumName();
+    if (EnumName.size() == 0) {
+      PrintError("enumClauseValue field not set in Clause" +
+                 C.getFormattedName() + ".");
+      return;
+    }
+    OS << "\n";
+    OS << EnumName << " llvm::" << Namespace << "::get" << EnumName
+       << "(llvm::StringRef Str) {\n";
+    OS << "  return llvm::StringSwitch<" << EnumName << ">(Str)\n";
+    for (const auto &CV : ClauseVals) {
+      ClauseVal CVal{CV};
+      OS << "    .Case(\"" << CVal.getFormattedName() << "\"," << CV->getName()
+         << ")\n";
+    }
+    OS << "    .Default(" << DefaultName << ");\n";
+    OS << "}\n";
+  }
 void GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
                                      raw_ostream &OS, StringRef DirectiveName,
                                      DirectiveLanguage &DirLang,
@@ -672,6 +603,9 @@ void EmitDirectivesImpl(RecordKeeper &Records, raw_ostream &OS) {
   // getClauseName(Clause Kind)
   GenerateGetName(Clauses, OS, "Clause", DirLang, DirLang.getClausePrefix());
+  // get<ClauseVal>Kind(StringRef Str)
+  GenerateGetKindClauseVal(Clauses, OS, DirLang.getCppNamespace());
   // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
   GenerateIsAllowedClause(Directives, OS, DirLang);

diff  --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index 1254bbe8b8fc..dae98bfe8ef9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,7 @@
+mlir_tablegen(OmpCommon.td --gen-directive-decl)
 mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
 mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 642282f8af18..eb92745d6fa5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -15,6 +15,7 @@
 #define OPENMP_OPS
 include "mlir/IR/OpBase.td"
+include "mlir/Dialect/OpenMP/OmpCommon.td"
 def OpenMP_Dialect : Dialect {
   let name = "omp";
@@ -42,18 +43,6 @@ def ClauseDefault : StrEnumAttr<
   let cppNamespace = "::mlir::omp";
-// Possible values for the proc_bind clause
-def ClauseProcMaster : StrEnumAttrCase<"master">;
-def ClauseProcClose : StrEnumAttrCase<"close">;
-def ClauseProcSpread : StrEnumAttrCase<"spread">;
-def ClauseProcBind : StrEnumAttr<
-    "ClauseProcBind",
-    "procbind clause",
-    [ClauseProcMaster, ClauseProcClose, ClauseProcSpread]> {
-  let cppNamespace = "::mlir::omp";
 def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
   let summary = "parallel construct";
   let description = [{
@@ -87,7 +76,7 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
-             OptionalAttr<ClauseProcBind>:$proc_bind_val);
+             OptionalAttr<ProcBindKind>:$proc_bind_val);
   let regions = (region AnyRegion:$region);

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b3b7e4c7afa5..215c1910f744 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -408,32 +408,31 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
       blockMapping[&bb] = llvmBB;
-      // Then, convert blocks one by one in topological order to ensure
-      // defs are converted before uses.
-      llvm::SetVector<Block *> blocks = topologicalSort(region);
-      for (auto indexedBB : llvm::enumerate(blocks)) {
-        Block *bb = indexedBB.value();
-        llvm::BasicBlock *curLLVMBB = blockMapping[bb];
-        if (bb->isEntryBlock())
-          codeGenIPBBTI->setSuccessor(0, curLLVMBB);
-        // TODO: Error not returned up the hierarchy
-        if (failed(
-                convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
-          return;
-        // If this block has the terminator then add a jump to
-        // continuation bb
-        for (auto &op : *bb) {
-          if (isa<omp::TerminatorOp>(op)) {
-            builder.SetInsertPoint(curLLVMBB);
-            builder.CreateBr(&continuationIP);
-          }
+    // Then, convert blocks one by one in topological order to ensure
+    // defs are converted before uses.
+    llvm::SetVector<Block *> blocks = topologicalSort(region);
+    for (auto indexedBB : llvm::enumerate(blocks)) {
+      Block *bb = indexedBB.value();
+      llvm::BasicBlock *curLLVMBB = blockMapping[bb];
+      if (bb->isEntryBlock())
+        codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+      // TODO: Error not returned up the hierarchy
+      if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+        return;
+      // If this block has the terminator then add a jump to
+      // continuation bb
+      for (auto &op : *bb) {
+        if (isa<omp::TerminatorOp>(op)) {
+          builder.SetInsertPoint(curLLVMBB);
+          builder.CreateBr(&continuationIP);
-      // Finally, after all blocks have been traversed and values mapped,
-      // connect the PHI nodes to the results of preceding blocks.
-      connectPHINodes(region, valueMapping, blockMapping);
+    }
+    // Finally, after all blocks have been traversed and values mapped,
+    // connect the PHI nodes to the results of preceding blocks.
+    connectPHINodes(region, valueMapping, blockMapping);
   // TODO: Perform appropriate actions according to the data-sharing
@@ -451,23 +450,24 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
   // called for variables which have destructors/finalizers.
   auto finiCB = [&](InsertPointTy codeGenIP) {};
-  // TODO: The various operands of parallel operation are not handled.
-  // Parallel operation is created with some default options for now.
   llvm::Value *ifCond = nullptr;
   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
     ifCond = valueMapping.lookup(ifExprVar);
   llvm::Value *numThreads = nullptr;
   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
     numThreads = valueMapping.lookup(numThreadsVar);
+  llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
+  if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
+    pbKind = llvm::omp::getProcBindKind(bind.getValue());
   // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
   // TODO: Determine the actual alloca insertion point, e.g., the function
   // entry or the alloca insertion point as provided by the body callback
   // above.
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
-  builder.restoreIP(ompBuilder->CreateParallel(
-      builder, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads,
-      llvm::omp::OMP_PROC_BIND_default, isCancellable));
+  builder.restoreIP(
+      ompBuilder->CreateParallel(builder, allocaIP, bodyGenCB, privCB, finiCB,
+                                 ifCond, numThreads, pbKind, isCancellable));
   return success();

diff  --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
index 60462fee3b97..1cfae359b935 100644
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -175,3 +175,34 @@ llvm.func @test_omp_parallel_if_1(%arg0: !llvm.i32) -> () {
 // CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
   // CHECK: call void @__kmpc_barrier
+// CHECK-LABEL: define void @test_omp_parallel_3()
+llvm.func @test_omp_parallel_3() -> () {
+  // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_1]], i32 2)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_1:.*]] to {{.*}}
+  omp.parallel proc_bind(master) {
+    omp.barrier
+    omp.terminator
+  }
+  // CHECK: [[OMP_THREAD_3_2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_2]], i32 3)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_2:.*]] to {{.*}}
+  omp.parallel proc_bind(close) {
+    omp.barrier
+    omp.terminator
+  }
+  // CHECK: [[OMP_THREAD_3_3:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_3]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_3:.*]] to {{.*}}
+  omp.parallel proc_bind(spread) {
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]

diff  --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 5af892438fae..46b9d81115c9 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
+  OpenMPCommonGen.cpp

diff  --git a/mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp b/mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp
new file mode 100644
index 000000000000..689953587a24
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp
@@ -0,0 +1,73 @@
+//===========- OpenMPCommonGen.cpp - OpenMP common info generator -===========//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// OpenMPCommonGen generates utility information from the single OpenMP source
+// of truth in llvm/include/llvm/Frontend/OpenMP.
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/DirectiveEmitter.h"
+#include "llvm/TableGen/Record.h"
+using llvm::Clause;
+using llvm::ClauseVal;
+using llvm::raw_ostream;
+using llvm::RecordKeeper;
+using llvm::Twine;
+static bool emitDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  const auto &clauses = recordKeeper.getAllDerivedDefinitions("Clause");
+  for (const auto &r : clauses) {
+    Clause c{r};
+    const auto &clauseVals = c.getClauseVals();
+    if (clauseVals.size() <= 0)
+      continue;
+    const auto enumName = c.getEnumName();
+    assert(enumName.size() != 0 && "enumClauseValue field not set.");
+    std::vector<std::string> cvDefs;
+    for (const auto &cv : clauseVals) {
+      ClauseVal cval{cv};
+      if (!cval.isUserVisible())
+        continue;
+      const auto name = cval.getFormattedName();
+      std::string cvDef{(enumName + llvm::Twine(name)).str()};
+      os << "def " << cvDef << " : StrEnumAttrCase<\"" << name << "\">;\n";
+      cvDefs.push_back(cvDef);
+    }
+    os << "def " << enumName << ": StrEnumAttr<\n";
+    os << "  \"Clause" << enumName << "\",\n";
+    os << "  \"" << enumName << " Clause\",\n";
+    os << "  [";
+    for (unsigned int i = 0; i < cvDefs.size(); i++) {
+      os << cvDefs[i];
+      if (i != cvDefs.size() - 1)
+        os << ",";
+    }
+    os << "]> {\n";
+    os << "    let cppNamespace = \"::mlir::omp\";\n";
+    os << "}\n";
+  }
+  return false;
+// Registers the generator to mlir-tblgen.
+static mlir::GenRegistration
+    genDirectiveDecls("gen-directive-decl",
+                      "Generate declarations for directives (OpenMP etc.)",
+                      [](const RecordKeeper &records, raw_ostream &os) {
+                        return emitDecls(records, os);
+                      });


More information about the Mlir-commits mailing list