[Mlir-commits] [mlir] e0c3b94 - [mlir] Restrict dialect doc gen to a single dialect

River Riddle llvmlistbot at llvm.org
Mon May 16 15:35:18 PDT 2022


Author: River Riddle
Date: 2022-05-16T15:35:07-07:00
New Revision: e0c3b94c80143376473ec7110ca0c8a4fe03112e

URL: https://github.com/llvm/llvm-project/commit/e0c3b94c80143376473ec7110ca0c8a4fe03112e
DIFF: https://github.com/llvm/llvm-project/commit/e0c3b94c80143376473ec7110ca0c8a4fe03112e.diff

LOG: [mlir] Restrict dialect doc gen to a single dialect

In the overwhelming majority of cases, only one dialect is generated at a time
anyway, and this restriction makes it easier to catch user error when multiple
dialects might be generated. We hit this fairly recently with the PDL dialect,
and CIRCT and other downstream users are actively hitting this too.

Differential Revision: https://reviews.llvm.org/D125651

Added: 
    mlir/tools/mlir-tblgen/DialectGenUtilities.h

Modified: 
    mlir/include/mlir/Dialect/AMX/CMakeLists.txt
    mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
    mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
    mlir/test/mlir-tblgen/gen-dialect-doc.td
    mlir/tools/mlir-tblgen/DialectGen.cpp
    mlir/tools/mlir-tblgen/OpDocGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index ae9b201dae75..f3f1aff5a636 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect(AMX amx)
-add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc)
+add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 
 set(LLVM_TARGET_DEFINITIONS AMX.td)
 mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
index 143497c642f4..1c679bcd049b 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect(ArmNeon arm_neon)
-add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc)
+add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
 
 set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
 mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)

diff  --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index 4ddd619311cc..06595b7088a1 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect(ArmSVE arm_sve ArmSVE)
-add_mlir_doc(ArmSVE ArmSVE Dialects/ -gen-dialect-doc)
+add_mlir_doc(ArmSVE ArmSVE Dialects/ -gen-dialect-doc -dialect=arm_sve)
 
 set(LLVM_TARGET_DEFINITIONS ArmSVE.td)
 mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions)

diff  --git a/mlir/test/mlir-tblgen/gen-dialect-doc.td b/mlir/test/mlir-tblgen/gen-dialect-doc.td
index 1eda916e814c..02640f531d4e 100644
--- a/mlir/test/mlir-tblgen/gen-dialect-doc.td
+++ b/mlir/test/mlir-tblgen/gen-dialect-doc.td
@@ -1,4 +1,5 @@
-// RUN: mlir-tblgen -gen-dialect-doc -I %S/../../include %s | FileCheck %s
+// RUN: mlir-tblgen -gen-dialect-doc -I %S/../../include -dialect=test %s | FileCheck %s
+// RUN: mlir-tblgen -gen-dialect-doc -I %S/../../include -dialect=test_toc %s | FileCheck %s --check-prefix=CHECK_TOC
 
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -55,6 +56,6 @@ def Toc_Dialect : Dialect {
 }
 def BOp : Op<Toc_Dialect, "b", []>;
 
-// CHECK: Dialect with
-// CHECK: [TOC]
-// CHECK: here.
+// CHECK_TOC: Dialect with
+// CHECK_TOC: [TOC]
+// CHECK_TOC: here.

diff  --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 347f08f36f81..8944e8344187 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "DialectGenUtilities.h"
 #include "mlir/TableGen/Class.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
@@ -55,28 +56,30 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
           DialectFilterIterator(records.end(), records.end(), filterFn)};
 }
 
-static Optional<Dialect>
-findSelectedDialect(ArrayRef<const llvm::Record *> dialectDefs) {
-  // Select the dialect to gen for.
-  if (dialectDefs.size() == 1 && selectedDialect.getNumOccurrences() == 0) {
-    return Dialect(dialectDefs.front());
+Optional<Dialect> tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
+  if (dialects.empty()) {
+    llvm::errs() << "no dialect was found\n";
+    return llvm::None;
   }
 
+  // Select the dialect to gen for.
+  if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
+    return dialects.front();
+
   if (selectedDialect.getNumOccurrences() == 0) {
     llvm::errs() << "when more than 1 dialect is present, one must be selected "
                     "via '-dialect'\n";
     return llvm::None;
   }
 
-  const auto *dialectIt =
-      llvm::find_if(dialectDefs, [](const llvm::Record *def) {
-        return Dialect(def).getName() == selectedDialect;
-      });
-  if (dialectIt == dialectDefs.end()) {
+  const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
+    return dialect.getName() == selectedDialect;
+  });
+  if (dialectIt == dialects.end()) {
     llvm::errs() << "selected dialect with '-dialect' does not exist\n";
     return llvm::None;
   }
-  return Dialect(*dialectIt);
+  return *dialectIt;
 }
 
 //===----------------------------------------------------------------------===//
@@ -235,7 +238,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
   if (dialectDefs.empty())
     return false;
 
-  Optional<Dialect> dialect = findSelectedDialect(dialectDefs);
+  SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+  Optional<Dialect> dialect = findDialectToGenerate(dialects);
   if (!dialect)
     return true;
   auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
@@ -308,7 +312,8 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
   if (dialectDefs.empty())
     return false;
 
-  Optional<Dialect> dialect = findSelectedDialect(dialectDefs);
+  SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+  Optional<Dialect> dialect = findDialectToGenerate(dialects);
   if (!dialect)
     return true;
   emitDialectDef(*dialect, os);

diff  --git a/mlir/tools/mlir-tblgen/DialectGenUtilities.h b/mlir/tools/mlir-tblgen/DialectGenUtilities.h
new file mode 100644
index 000000000000..80fed9626deb
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectGenUtilities.h
@@ -0,0 +1,24 @@
+//===- DialectGenUtilities.h - Utilities for dialect generation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_
+#define MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace tblgen {
+class Dialect;
+
+/// Find the dialect selected by the user to generate for. Returns None if no
+/// dialect was found, or if more than one potential dialect was found.
+Optional<Dialect> findDialectToGenerate(ArrayRef<Dialect> dialects);
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_

diff  --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index 83229cda37de..8d66448da0e9 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "DialectGenUtilities.h"
 #include "DocGenUtilities.h"
 #include "OpGenHelpers.h"
 #include "mlir/Support/IndentedOstream.h"
@@ -18,6 +19,7 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -35,8 +37,6 @@ using namespace mlir::tblgen;
 
 using mlir::tblgen::Operator;
 
-extern llvm::cl::opt<std::string> selectedDialect;
-
 // Emit the description by aligning the text to the left per line (e.g.,
 // removing the minimum indentation across the block).
 //
@@ -307,9 +307,6 @@ static void emitDialectDoc(const Dialect &dialect,
                            ArrayRef<AttrDef> attrDefs, ArrayRef<Operator> ops,
                            ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
                            raw_ostream &os) {
-  if (selectedDialect.getNumOccurrences() &&
-      dialect.getName() != selectedDialect)
-    return;
   os << "# '" << dialect.getName() << "' Dialect\n\n";
   emitIfNotEmpty(dialect.getSummary(), os);
   emitIfNotEmpty(dialect.getDescription(), os);
@@ -351,7 +348,7 @@ static void emitDialectDoc(const Dialect &dialect,
   }
 }
 
-static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
   std::vector<Record *> attrDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr");
@@ -362,7 +359,8 @@ static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   std::vector<Record *> attrDefDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");
 
-  std::set<Dialect> dialectsWithDocs;
+  llvm::SetVector<Dialect, SmallVector<Dialect>, std::set<Dialect>>
+      dialectsWithDocs;
 
   llvm::StringMap<std::vector<Attribute>> dialectAttrs;
   llvm::StringMap<std::vector<AttrDef>> dialectAttrDefs;
@@ -399,13 +397,17 @@ static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
     dialectsWithDocs.insert(type.getDialect());
   }
 
+  Optional<Dialect> dialect =
+      findDialectToGenerate(dialectsWithDocs.getArrayRef());
+  if (!dialect)
+    return true;
+
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  for (const Dialect &dialect : dialectsWithDocs) {
-    StringRef dialectName = dialect.getName();
-    emitDialectDoc(dialect, dialectAttrs[dialectName],
-                   dialectAttrDefs[dialectName], dialectOps[dialectName],
-                   dialectTypes[dialectName], dialectTypeDefs[dialectName], os);
-  }
+  StringRef dialectName = dialect->getName();
+  emitDialectDoc(*dialect, dialectAttrs[dialectName],
+                 dialectAttrDefs[dialectName], dialectOps[dialectName],
+                 dialectTypes[dialectName], dialectTypeDefs[dialectName], os);
+  return false;
 }
 
 //===----------------------------------------------------------------------===//
@@ -437,6 +439,5 @@ static mlir::GenRegistration
 static mlir::GenRegistration
     genRegister("gen-dialect-doc", "Generate dialect documentation",
                 [](const RecordKeeper &records, raw_ostream &os) {
-                  emitDialectDoc(records, os);
-                  return false;
+                  return emitDialectDoc(records, os);
                 });


        


More information about the Mlir-commits mailing list