[Mlir-commits] [mlir] [mlir][ODS] Add namespace filtering to `-gen-enum-*` (PR #89627)

Markus Böck llvmlistbot at llvm.org
Mon Apr 22 09:39:47 PDT 2024


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/89627

Unlike all other ODS generators, the enum generators currently offer no way to filter which enums TableGen should generate. This is problematic if a dialect depends on another dialect and therefore transitively `#include`s the other dialect's enums: the declarations and definitions of those enums will be generated in both dialects, which will likely lead to a compiler error and is guaranteed to lead to a linker error.

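This PR therefore adds a new command line flag called `-enum-namespace`, inspired by the existing `-(attr|type)-dialect` flags, to restrict the set of enums that should be generated. Unlike attributes and types, enums are not part of a dialect, which makes the same filtering mechanism impossible. As an alternative, filtering via the C++ namespace was implemented. The flag is matched as a prefix of whole namespace components: for example, `-enum-namespace=test` generates enums declared in `test::ns1` and `test::ns2`, whereas `-enum-namespace=test::ns` generates neither.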

>From 947a09f495301cbba748d0a7a220d86292f6a390 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Mon, 22 Apr 2024 17:38:57 +0100
Subject: [PATCH] [mlir][ODS] Add namespace filtering to `-gen-enum-*`

Unlike all other ODS generators, the enum generators currently offer no way to filter which enums TableGen should generate. This is problematic if a dialect depends on another dialect and therefore transitively `#include`s the other dialect's enums: the declarations and definitions of those enums will be generated in both dialects, which will likely lead to a compiler error and is guaranteed to lead to a linker error.

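This PR therefore adds a new command line flag called `-enum-namespace`, inspired by the existing `-(attr|type)-dialect` flags, to restrict the set of enums that should be generated. Unlike attributes and types, enums are not part of a dialect, which makes the same filtering mechanism impossible. As an alternative, filtering via the C++ namespace was implemented. The flag is matched as a prefix of whole namespace components: for example, `-enum-namespace=test` generates enums declared in `test::ns1` and `test::ns2`, whereas `-enum-namespace=test::ns` generates neither.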
---
 .../mlir-tblgen/enums-gen-namespace-filter.td | 30 ++++++++++
 mlir/tools/mlir-tblgen/EnumsGen.cpp           | 60 ++++++++++++++-----
 2 files changed, 76 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/enums-gen-namespace-filter.td

diff --git a/mlir/test/mlir-tblgen/enums-gen-namespace-filter.td b/mlir/test/mlir-tblgen/enums-gen-namespace-filter.td
new file mode 100644
index 00000000000000..e065745a4c2f02
--- /dev/null
+++ b/mlir/test/mlir-tblgen/enums-gen-namespace-filter.td
@@ -0,0 +1,30 @@
+// RUN: mlir-tblgen -gen-enum-decls -enum-namespace=test::ns1 -I %S/../../include %s | FileCheck %s --check-prefix=NS1
+// RUN: mlir-tblgen -gen-enum-decls -enum-namespace=test::ns2 -I %S/../../include %s | FileCheck %s --check-prefix=NS2
+// RUN: mlir-tblgen -gen-enum-decls -enum-namespace=test::ns -I %S/../../include %s | FileCheck %s --check-prefix=NS
+// RUN: mlir-tblgen -gen-enum-decls -enum-namespace=test -I %S/../../include %s | FileCheck %s --check-prefix=TEST-NS
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+def EnumNS1 : I32BitEnumAttr<"EnumNS1", "", []> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test::ns1";
+}
+
+def EnumNS2 : I32BitEnumAttr<"EnumNS2", "", []> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test::ns2";
+}
+
+// NS1-NOT: enum class EnumNS2
+// NS1: enum class EnumNS1
+// NS1-NOT: enum class EnumNS2
+
+// NS2-NOT: enum class EnumNS1
+// NS2: enum class EnumNS2
+
+// NS-NOT: enum class EnumNS1
+// NS-NOT: enum class EnumNS2
+
+// TEST-NS: enum class EnumNS1
+// TEST-NS: enum class EnumNS2
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index f1d7a233b66a9a..78102f5c497195 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -11,9 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "FormatGen.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
@@ -23,18 +25,32 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+using namespace mlir;
+using namespace mlir::tblgen;
+
 using llvm::formatv;
 using llvm::isDigit;
 using llvm::PrintFatalError;
 using llvm::raw_ostream;
 using llvm::Record;
 using llvm::RecordKeeper;
-using llvm::StringRef;
-using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
-using mlir::tblgen::FmtContext;
-using mlir::tblgen::tgfmt;
+
+/// Returns true if 'subNamespace' is a sub-namespace of 'parentNamespace'.
+/// I.e. 'subNamespace' is contained within 'parentNamespace'.
+static bool isSubNamespace(ArrayRef<StringRef> subNamespace,
+                           StringRef parentNamespace) {
+  SmallVector<StringRef> parentNamespaces;
+  llvm::SplitString(parentNamespace, parentNamespaces, "::");
+  // If the parent namespace has more components than the sub-namespace it
+  // cannot possibly be a parent namespace.
+  if (parentNamespaces.size() > subNamespace.size())
+    return false;
+
+  // Otherwise, make sure all components of the parent namespace match in the
+  // sub-namespace.
+  return llvm::equal(parentNamespaces,
+                     subNamespace.take_front(parentNamespaces.size()));
+}
 
 static std::string makeIdentifier(StringRef str) {
   if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
@@ -553,7 +569,8 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
   os << "}\n";
 }
 
-static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
+static void emitEnumDecl(const Record &enumDef, raw_ostream &os,
+                         StringRef namespacePrefix) {
   EnumAttr enumAttr(enumDef);
   StringRef enumName = enumAttr.getEnumClassName();
   StringRef cppNamespace = enumAttr.getCppNamespace();
@@ -568,6 +585,9 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
   llvm::SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
 
+  if (!isSubNamespace(namespaces, namespacePrefix))
+    return;
+
   for (auto ns : namespaces)
     os << "namespace " << ns << " {\n";
 
@@ -642,23 +662,28 @@ class {1} : public ::mlir::{2} {
   emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
 }
 
-static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os,
+                          StringRef namespacePrefix) {
   llvm::emitSourceFileHeader("Enum Utility Declarations", os, recordKeeper);
 
   auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
   for (const auto *def : defs)
-    emitEnumDecl(*def, os);
+    emitEnumDecl(*def, os, namespacePrefix);
 
   return false;
 }
 
-static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
+static void emitEnumDef(const Record &enumDef, raw_ostream &os,
+                        StringRef namespacePrefix) {
   EnumAttr enumAttr(enumDef);
   StringRef cppNamespace = enumAttr.getCppNamespace();
 
   llvm::SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
 
+  if (!isSubNamespace(namespaces, namespacePrefix))
+    return;
+
   for (auto ns : namespaces)
     os << "namespace " << ns << " {\n";
 
@@ -680,26 +705,33 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
   os << "\n";
 }
 
-static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os,
+                         StringRef namespacePrefix) {
   llvm::emitSourceFileHeader("Enum Utility Definitions", os, recordKeeper);
 
   auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
   for (const auto *def : defs)
-    emitEnumDef(*def, os);
+    emitEnumDef(*def, os, namespacePrefix);
 
   return false;
 }
 
+static llvm::cl::OptionCategory enumGenCat("Options for -gen-enum-*");
+static llvm::cl::opt<std::string>
+    enumNamespace("enum-namespace",
+                  llvm::cl::desc("Generate enums within this namespace"),
+                  llvm::cl::cat(enumGenCat));
+
 // Registers the enum utility generator to mlir-tblgen.
 static mlir::GenRegistration
     genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
                  [](const RecordKeeper &records, raw_ostream &os) {
-                   return emitEnumDecls(records, os);
+                   return emitEnumDecls(records, os, enumNamespace);
                  });
 
 // Registers the enum utility generator to mlir-tblgen.
 static mlir::GenRegistration
     genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
                 [](const RecordKeeper &records, raw_ostream &os) {
-                  return emitEnumDefs(records, os);
+                  return emitEnumDefs(records, os, enumNamespace);
                 });



More information about the Mlir-commits mailing list