[Mlir-commits] [mlir] Add support for enum doc gen (PR #98885)

Tom Natan llvmlistbot at llvm.org
Mon Jul 15 11:24:53 PDT 2024


https://github.com/tomnatan30 updated https://github.com/llvm/llvm-project/pull/98885

>From 41c2c0c0c8915c5530e0aead04ff9c3e3befa4d8 Mon Sep 17 00:00:00 2001
From: tomnatan30 <tomnatan at google.com>
Date: Mon, 15 Jul 2024 11:13:15 +0000
Subject: [PATCH 1/3] add support for enum doc gen

---
 mlir/test/mlir-tblgen/gen-dialect-doc.td | 21 +++++++
 mlir/tools/mlir-tblgen/OpDocGen.cpp      | 73 +++++++++++++++++++++---
 2 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/mlir/test/mlir-tblgen/gen-dialect-doc.td b/mlir/test/mlir-tblgen/gen-dialect-doc.td
index c9492eb9ac3ce..79d755111e8f6 100644
--- a/mlir/test/mlir-tblgen/gen-dialect-doc.td
+++ b/mlir/test/mlir-tblgen/gen-dialect-doc.td
@@ -3,6 +3,7 @@
 
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def Test_Dialect : Dialect {
@@ -69,6 +70,16 @@ def TestTypeDefParams : TypeDef<Test_Dialect, "TestTypeDefParams"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+def TestEnum : // Exercises the "## Enums" section; see the CHECK lines below.
+    I32EnumAttr<"TestEnum",
+        "enum summary", [
+        I32EnumAttrCase<"First", 0, "first">,
+        I32EnumAttrCase<"Second", 1, "second">,
+        I32EnumAttrCase<"Third", 2, "third">]> {
+  let genSpecializedAttr = 1;
+  let cppNamespace = "NS";
+}
+
 // CHECK: Dialect without a [TOC] here.
 // CHECK: TOC added by tool.
 // CHECK: [TOC]
@@ -109,6 +120,16 @@ def TestTypeDefParams : TypeDef<Test_Dialect, "TestTypeDefParams"> {
 // CHECK: Syntax:
 // CHECK: !test.test_type_def_params
 
+// CHECK: ## Enums
+// CHECK: ### TestEnum
+// CHECK: enum summary
+// CHECK: #### Cases:
+// CHECK: | Symbol | Value | String |
+// CHECK: | :----: | :---: | ------ |
+// CHECK: | First | `0` | first |
+// CHECK: | Second | `1` | second |
+// CHECK: | Third | `2` | third |
+
 def Toc_Dialect : Dialect {
   let name = "test_toc";
   let summary = "Dialect of ops to test";
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index 7cd2690ea8155..d55414d7b95f8 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -16,6 +16,7 @@
 #include "OpGenHelpers.h"
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
@@ -37,7 +38,7 @@
 // Commandline Options
 //===----------------------------------------------------------------------===//
 static llvm::cl::OptionCategory
-    docCat("Options for -gen-(attrdef|typedef|op|dialect)-doc");
+    docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
 llvm::cl::opt<std::string>
     stripPrefix("strip-prefix",
                 llvm::cl::desc("Strip prefix of the fully qualified names"),
@@ -381,6 +382,38 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
     emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
 }
 
+//===----------------------------------------------------------------------===//
+// Enum Documentation
+//===----------------------------------------------------------------------===//
+
+static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
+  os << llvm::formatv("### {0}\n", def.getEnumClassName());
+
+  // Emit the summary if present.
+  if (!def.getSummary().empty())
+    os << "\n" << def.getSummary() << "\n";
+
+  // Emit case documentation.
+  std::vector<EnumAttrCase> cases = def.getAllCases();
+  os << "\n#### Cases:\n\n";
+  os << "| Symbol | Value | String |\n"
+     << "| :----: | :---: | ------ |\n";
+  for (const auto &it : cases) {
+    os << "| " << it.getSymbol() << " | `" << it.getValue() << "` | "
+        << it.getStr() << " |\n";
+  }
+
+  os << "\n";
+}
+
+static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  std::vector<llvm::Record *> defs =
+      recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
+
+  os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
+  for (const llvm::Record *def : defs) emitEnumDoc(EnumAttr(def), os);
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect Documentation
 //===----------------------------------------------------------------------===//
@@ -413,7 +446,7 @@ static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
 static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
                       ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
                       ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
-                      raw_ostream &os) {
+                      ArrayRef<EnumAttr> enums, raw_ostream &os) {
   if (!ops.empty()) {
     os << "## Operations\n\n";
     emitSourceLink(inputFilename, os);
@@ -459,13 +492,19 @@ static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
     for (const TypeDef &def : typeDefs)
       emitAttrOrTypeDefDoc(def, os);
   }
+
+  if (!enums.empty()) {
+    os << "## Enums\n\n";
+    for (const EnumAttr &def : enums)
+      emitEnumDoc(def, os);
+  }
 }
 
 static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
                            ArrayRef<Attribute> attributes,
                            ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
                            ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
-                           raw_ostream &os) {
+                           ArrayRef<EnumAttr> enums, raw_ostream &os) {
   os << "# '" << dialect.getName() << "' Dialect\n\n";
   emitIfNotEmpty(dialect.getSummary(), os);
   emitIfNotEmpty(dialect.getDescription(), os);
@@ -475,7 +514,8 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
   if (!r.match(dialect.getDescription()))
     os << "[TOC]\n\n";
 
-  emitBlock(attributes, inputFilename, attrDefs, ops, types, typeDefs, os);
+  emitBlock(attributes, inputFilename, attrDefs, ops, types, typeDefs, enums,
+            os);
 }
 
 static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
@@ -495,21 +535,30 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
       recordKeeper.getAllDerivedDefinitionsIfDefined("TypeDef");
   std::vector<Record *> attrDefDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");
+  std::vector<Record *> enumDefs =
+      recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
 
   std::vector<Attribute> dialectAttrs;
   std::vector<AttrDef> dialectAttrDefs;
   std::vector<OpDocGroup> dialectOps;
   std::vector<Type> dialectTypes;
   std::vector<TypeDef> dialectTypeDefs;
+  std::vector<EnumAttr> dialectEnums;
 
   llvm::SmallDenseSet<Record *> seen;
-  auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
-    if (seen.insert(record).second && def.getDialect() == *dialect) {
+  auto addIfNotSeen = [&](llvm::Record *record, const auto &def, auto &vec) {
+    if (seen.insert(record).second) {
       vec.push_back(def);
       return true;
     }
     return false;
   };
+  auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
+    if (def.getDialect() == *dialect) {
+      return addIfNotSeen(record, def, vec);
+    }
+    return false;
+  };
 
   SmallDenseMap<Record *, OpDocGroup> opDocGroup;
 
@@ -539,6 +588,9 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
     addIfInDialect(def, TypeDef(def), dialectTypeDefs);
   for (Record *def : typeDefs)
     addIfInDialect(def, Type(def), dialectTypes);
+  dialectEnums.reserve(enumDefs.size());
+  for (Record *def : enumDefs)
+    addIfNotSeen(def, EnumAttr(def), dialectEnums);
 
   // Sort alphabetically ignorning dialect for ops and section name for
   // sections.
@@ -557,7 +609,7 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
   emitDialectDoc(*dialect, recordKeeper.getInputFilename(), dialectAttrs,
                  dialectAttrDefs, dialectOps, dialectTypes, dialectTypeDefs,
-                 os);
+                 dialectEnums, os);
   return false;
 }
 
@@ -587,6 +639,13 @@ static mlir::GenRegistration
                       return false;
                     });
 
+static mlir::GenRegistration
+    genEnumRegister("gen-enum-doc", "Generate dialect enum documentation",
+                    [](const RecordKeeper &records, raw_ostream &os) {
+                      emitEnumDoc(records, os);
+                      return false;
+                    });
+
 static mlir::GenRegistration
     genRegister("gen-dialect-doc", "Generate dialect documentation",
                 [](const RecordKeeper &records, raw_ostream &os) {

>From 2d36a479a76bd479450161c141a4c56ef79db9db Mon Sep 17 00:00:00 2001
From: tomnatan30 <tomnatan at google.com>
Date: Mon, 15 Jul 2024 12:41:09 +0000
Subject: [PATCH 2/3] fix clang format

---
 mlir/tools/mlir-tblgen/OpDocGen.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index d55414d7b95f8..cb3dc5b45d00c 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -229,8 +229,7 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
         // Expandable description.
         // This appears as just the summary, but when clicked shows the full
         // description.
-        os << "<details>"
-           << "<summary>" << it.attr.getSummary() << "</summary>"
+        os << "<details>" << "<summary>" << it.attr.getSummary() << "</summary>"
            << "{{% markdown %}}" << description << "{{% /markdown %}}"
            << "</details>";
       } else {
@@ -400,7 +399,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
      << "| :----: | :---: | ------ |\n";
   for (const auto &it : cases) {
     os << "| " << it.getSymbol() << " | `" << it.getValue() << "` | "
-        << it.getStr() << " |\n";
+       << it.getStr() << " |\n";
   }
 
   os << "\n";
@@ -411,7 +410,8 @@ static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  for (const llvm::Record *def : defs) emitEnumDoc(EnumAttr(def), os);
+  for (const llvm::Record *def : defs)
+    emitEnumDoc(EnumAttr(def), os);
 }
 
 //===----------------------------------------------------------------------===//

>From 1dc69162197e297b6dcbf8e94b833c065c6cb94d Mon Sep 17 00:00:00 2001
From: tomnatan30 <tomnatan at google.com>
Date: Mon, 15 Jul 2024 18:24:39 +0000
Subject: [PATCH 3/3] resolve review comments

---
 mlir/tools/mlir-tblgen/OpDocGen.cpp | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index cb3dc5b45d00c..11f5e18b61de1 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -394,12 +394,14 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
 
   // Emit case documentation.
   std::vector<EnumAttrCase> cases = def.getAllCases();
-  os << "\n#### Cases:\n\n";
-  os << "| Symbol | Value | String |\n"
-     << "| :----: | :---: | ------ |\n";
-  for (const auto &it : cases) {
-    os << "| " << it.getSymbol() << " | `" << it.getValue() << "` | "
-       << it.getStr() << " |\n";
+  if (!cases.empty()) { // Omit the heading and table for case-less enums.
+    os << "\n#### Cases:\n\n";
+    os << "| Symbol | Value | String |\n"
+       << "| :----: | :---: | ------ |\n";
+    for (const auto &it : cases) {
+      os << "| " << it.getSymbol() << " | `" << it.getValue() << "` | "
+         << it.getStr() << " |\n";
+    }
   }
 
   os << "\n";
@@ -554,10 +556,7 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
     return false;
   };
   auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
-    if (def.getDialect() == *dialect) {
-      return addIfNotSeen(record, def, vec);
-    }
-    return false;
+    return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
   };
 
   SmallDenseMap<Record *, OpDocGroup> opDocGroup;



More information about the Mlir-commits mailing list