[Mlir-commits] [mlir] [MLIR][TblGen] add AttrOrTypeCAPIGen (PR #172590)

Maksim Levental llvmlistbot at llvm.org
Mon Dec 22 16:49:24 PST 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172590

>From 2d2288b6e0b8d320c4e6be2e9a5cd401b9b70d99 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 16 Dec 2025 19:41:23 -0800
Subject: [PATCH] [MLIR][TblGen] add AttrOrTypeCAPIGen

---
 mlir/include/mlir/IR/EnumAttr.td             |   1 +
 mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp | 248 +++++++++++++++++++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp  |  42 +---
 mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h |  44 ++++
 mlir/tools/mlir-tblgen/CMakeLists.txt        |   1 +
 5 files changed, 297 insertions(+), 39 deletions(-)
 create mode 100644 mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp

diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 4bc94809ed3ce..87b2d793a304e 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -493,6 +493,7 @@ class EnumParameter<EnumInfo enumInfo>
     !cast<EnumAttrInfo>(enumInfo).parameterParser, ?);
   let printer = !if(!isa<EnumAttrInfo>(enumInfo),
     !cast<EnumAttrInfo>(enumInfo).parameterPrinter, ?);
+  string underlyingEnumName = enumInfo.className;
 }
 
 // An attribute backed by a C++ enum. The attribute contains a single
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp
new file mode 100644
index 0000000000000..89720c38e6053
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeCAPIGen.cpp
@@ -0,0 +1,248 @@
+//===- AttrOrTypeCAPIGen.cpp - MLIR Attribute and Type CAPI generation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "AttrOrTypeFormatGen.h"
+#include "CppGenUtilities.h"
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "mlir/TableGen/Pass.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#define DEBUG_TYPE "mlir-tblgen-attrortypecapigen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+using llvm::formatv;
+using llvm::Record;
+using llvm::RecordKeeper;
+
+/// Option category shared by the -gen-attrdef-capi-* and -gen-typedef-capi-*
+/// generators registered in this file.
+static llvm::cl::OptionCategory attrOrTypeCAPIDefGenCat(
+    "Options for -gen-attrdef-capi-* and -gen-typedef-capi-*");
+/// Dialect whose defs get C API declarations. Unless
+/// -attr-or-type-capi-namespace-prefix is given, the capitalized dialect name
+/// also forms the C symbol prefix (see namespacePrefix()).
+static llvm::cl::opt<std::string>
+    capiDialect("attr-or-type-capi-dialect",
+                llvm::cl::desc("Generate C APIs for this dialect"),
+                llvm::cl::cat(attrOrTypeCAPIDefGenCat),
+                llvm::cl::CommaSeparated);
+/// Replaces the dialect-derived part of the C symbol prefix (the text that
+/// follows "mlir").
+static llvm::cl::opt<std::string> capiNamespacePrefix(
+    "attr-or-type-capi-namespace-prefix",
+    llvm::cl::desc("Generate C APIs with this namespace prefix"),
+    llvm::cl::cat(attrOrTypeCAPIDefGenCat));
+
+/// Turns `str` into a usable C identifier. C identifiers may not start with a
+/// digit, so a leading '_' is prepended in that case; otherwise `str` is
+/// returned unchanged.
+static std::string makeIdentifier(StringRef str) {
+  std::string ident;
+  if (!str.empty() && llvm::isDigit(str.front()))
+    ident.push_back('_');
+  ident.append(str.begin(), str.end());
+  return ident;
+}
+
+/// Returns `name` with its first character upper-cased (e.g. "foo" -> "Foo").
+/// An empty string is returned as-is.
+static std::string withCapitalFirstLetter(std::string name) {
+  if (name.empty())
+    return name;
+  auto &first = name.front();
+  first = static_cast<char>(std::toupper(static_cast<unsigned char>(first)));
+  return name;
+}
+
+/// Returns the prefix prepended to every generated C symbol. The prefix is
+/// "mlir" followed by the -attr-or-type-capi-namespace-prefix value when that
+/// option is non-empty. Otherwise "mlir" is followed by the
+/// -attr-or-type-capi-dialect value with its first letter capitalized.
+/// NOTE(review): when neither option is set the prefix is just "mlir", even
+/// though collectAllDefs may still infer a single dialect; confirm this is
+/// intended.
+static StringRef namespacePrefix() {
+  // Options are fixed once parsed, so the prefix is computed on first use.
+  // The function-local static also keeps the returned StringRef valid.
+  static const std::string cachedPrefix =
+      capiNamespacePrefix.empty()
+          ? "mlir" + withCapitalFirstLetter(capiDialect.getValue())
+          : "mlir" + capiNamespacePrefix.getValue();
+  return cachedPrefix;
+}
+
+/// Emits a C enum for every EnumInfo record in `records`.
+///
+/// Each enum is named `<prefix><EnumClassName>` and is preceded by its summary
+/// as a line comment. It is followed by a same-named typedef, matching the
+/// MLIR C API convention (e.g.
+/// `typedef enum MlirDiagnosticSeverity MlirDiagnosticSeverity;`). The typedef
+/// lets C code, including the declarations generated by this file, spell the
+/// type without the `enum` keyword.
+/// NOTE(review): `enum X : T` (emitted when an underlying type is set) needs
+/// C23 or C++. Enumerators are also not prefixed, so cases from different
+/// enums may collide in C's single enumerator namespace.
+static void emitEnums(const RecordKeeper &records, raw_ostream &os) {
+  for (const Record *it :
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
+    EnumInfo enumInfo(*it);
+    std::string enumName =
+        (llvm::Twine(namespacePrefix()) + enumInfo.getEnumClassName()).str();
+
+    os << "// " << enumInfo.getSummary() << "\n";
+    os << "enum " << enumName;
+    if (!enumInfo.getUnderlyingType().empty())
+      os << " : " << enumInfo.getUnderlyingType();
+    os << " {\n";
+
+    for (const EnumCase &enumerant : enumInfo.getAllCases()) {
+      std::string symbol = makeIdentifier(enumerant.getSymbol());
+      auto value = enumerant.getValue();
+      // Negative case values are emitted without an explicit initializer.
+      if (value >= 0)
+        os << formatv("  {0} = {1},\n", symbol, value);
+      else
+        os << formatv("  {0},\n", symbol);
+    }
+    os << "};\n";
+    os << formatv("typedef enum {0} {0};\n\n", enumName);
+  }
+}
+
+namespace {
+/// Base generator for the C API declarations of AttrDefs/TypeDefs.
+///
+/// DefGenerator collects the def records and sorts them by occurrence in the
+/// input. This class also keeps the RecordKeeper so that emitDecls can emit
+/// the EnumInfo enums alongside the defs.
+struct CAPIDefGenerator : DefGenerator {
+  /// `className` selects the records to process (e.g. "AttrDef" or
+  /// "TypeDef"); the remaining arguments are forwarded to DefGenerator.
+  CAPIDefGenerator(const RecordKeeper &records, StringRef className,
+                   raw_ostream &os, StringRef defType, StringRef valueType,
+                   bool isAttrGenerator)
+      : DefGenerator(records.getAllDerivedDefinitionsIfDefined(className), os,
+                     defType, valueType, isAttrGenerator),
+        records(records) {}
+
+  /// Emits the C API header content for the defs of `selectedDialect`.
+  /// NOTE(review): emitDefs is not overridden, so DefGenerator's
+  /// implementation is still what emitDefs would run.
+  bool emitDecls(StringRef selectedDialect) override;
+
+  /// Source of the EnumInfo records emitted by emitDecls.
+  const RecordKeeper &records;
+};
+} // namespace
+
+/// Returns the C API spelling of the C++ type of `param`.
+///
+/// - Enum parameters map to `<prefix><underlyingEnumName>`.
+/// - The builtin `Type` / `Attribute` handles map to MlirType / MlirAttribute.
+///   This applies whether they are written unqualified, `mlir::`- or
+///   `::mlir::`-qualified.
+/// - Any other C++ type is passed through unchanged.
+///
+/// The result is an owning std::string. llvm::Twine must not be returned from
+/// a function because it may point at temporaries of the return expression.
+static std::string mapParamTypeToCAPI(const AttrOrTypeParameter &param) {
+  if (const auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
+    const Record *paramDef = defInit->getDef();
+    if (paramDef->isSubClassOf("EnumParameter"))
+      return (llvm::Twine(namespacePrefix()) +
+              paramDef->getValueAsString("underlyingEnumName"))
+          .str();
+  }
+  StringRef cppType = param.getCppType();
+  // Strip an optional `::` and `mlir::` qualification before matching.
+  StringRef unqualified = cppType;
+  unqualified.consume_front("::");
+  unqualified.consume_front("mlir::");
+  if (unqualified == "Type")
+    return "MlirType";
+  if (unqualified == "Attribute")
+    return "MlirAttribute";
+  return cppType.str();
+}
+
+/// Emits the declaration of the C constructor function for `def`:
+///   MLIR_CAPI_EXPORTED Mlir{Attribute,Type} <prefix><Class>Get(
+///       MlirContext context[, <CType> <param>]...);
+static void emitGettorDecl(const AttrOrTypeDef &def, raw_ostream &os,
+                           bool isAttrGenerator) {
+  const char *resultType = isAttrGenerator ? "MlirAttribute " : "MlirType ";
+  os << "MLIR_CAPI_EXPORTED " << resultType << namespacePrefix()
+     << def.getCppClassName() << "Get(MlirContext context";
+  // The context is always the first argument, so every parameter that follows
+  // is introduced by a ", " separator.
+  for (const AttrOrTypeParameter &param : def.getParameters())
+    os << ", " << mapParamTypeToCAPI(param) << " " << param.getName();
+  os << ");\n";
+}
+
+/// Emits one C accessor declaration per parameter of `def`:
+///   MLIR_CAPI_EXPORTED <CType> <prefix><Class>Get<Param>(MlirAttribute attr);
+/// (or `(MlirType type);` when generating for types).
+/// NOTE(review): a parameter named `typeID` would produce `...GetTypeID`,
+/// clashing with the symbol declared by emitTypeIDDecl.
+static void emitAccessorDecls(const AttrOrTypeDef &def, raw_ostream &os,
+                              bool isAttrGenerator) {
+  // Each declaration must end in ';' to be a valid C prototype.
+  StringRef handleParam =
+      isAttrGenerator ? "(MlirAttribute attr);\n" : "(MlirType type);\n";
+  for (const AttrOrTypeParameter &param : def.getParameters()) {
+    os << "MLIR_CAPI_EXPORTED " << mapParamTypeToCAPI(param) << " "
+       << namespacePrefix() << def.getCppClassName() << "Get"
+       << withCapitalFirstLetter(param.getName().str()) << handleParam;
+  }
+}
+
+/// Emits the declaration
+///   MLIR_CAPI_EXPORTED MlirTypeID <prefix><Class>GetTypeID();
+/// for `def`.
+static void emitTypeIDDecl(const AttrOrTypeDef &def, raw_ostream &os) {
+  StringRef className = def.getCppClassName();
+  os << "MLIR_CAPI_EXPORTED MlirTypeID " << namespacePrefix() << className
+     << "GetTypeID();\n";
+}
+
+/// Emits a banner comment naming `def`'s C++ class, so that each def's
+/// declarations form a visually separate section of the generated header.
+static void emitThingHeader(const AttrOrTypeDef &def, raw_ostream &os) {
+  static constexpr const char *banner = R"(
+//===----------------------------------------------------------------------===//
+// {0}
+//===----------------------------------------------------------------------===//
+
+)";
+  os << formatv(banner, def.getCppClassName());
+}
+
+/// Writes the C API header body in this order:
+/// 1. the file banner;
+/// 2. every EnumInfo enum;
+/// 3. for each def of `selectedDialect`: a section banner, the constructor
+///    declaration, the TypeID getter and, when the def generates accessors,
+///    one accessor per parameter.
+/// Returns false (success) on every path that returns.
+/// NOTE(review): enums are emitted regardless of the selected dialect and by
+/// both the attr and the type generator; check for duplicate definitions when
+/// both headers are included together.
+bool CAPIDefGenerator::emitDecls(StringRef selectedDialect) {
+  emitSourceFileHeader((defType + "Def C API Def Declarations").str(), os);
+  emitEnums(records, os);
+
+  SmallVector<AttrOrTypeDef, 16> dialectDefs;
+  collectAllDefs(selectedDialect, defRecords, dialectDefs);
+  if (!dialectDefs.empty()) {
+    for (const AttrOrTypeDef &def : dialectDefs) {
+      emitThingHeader(def, os);
+      emitGettorDecl(def, os, isAttrGenerator);
+      emitTypeIDDecl(def, os);
+      if (def.genAccessors())
+        emitAccessorDecls(def, os, isAttrGenerator);
+    }
+    os << "\n";
+  }
+  return false;
+}
+
+namespace {
+/// C API generator over all AttrDef records. The generated functions return
+/// or accept MlirAttribute handles.
+struct CAPIAttrDefGenerator final : CAPIDefGenerator {
+  CAPIAttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
+      : CAPIDefGenerator(records, /*className=*/"AttrDef", os,
+                         /*defType=*/"Attr", /*valueType=*/"Attribute",
+                         /*isAttrGenerator=*/true) {}
+};
+
+/// C API generator over all TypeDef records. The generated functions return
+/// or accept MlirType handles.
+struct CAPITypeDefGenerator final : CAPIDefGenerator {
+  CAPITypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
+      : CAPIDefGenerator(records, /*className=*/"TypeDef", os,
+                         /*defType=*/"Type", /*valueType=*/"Type",
+                         /*isAttrGenerator=*/false) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Registration hooks
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// AttrDef
+//===----------------------------------------------------------------------===//
+
+/// `-gen-attrdef-capi-decls`: emits C API declarations for the AttrDefs of
+/// the dialect selected by -attr-or-type-capi-dialect.
+static mlir::GenRegistration genAttrDecls(
+    "gen-attrdef-capi-decls", "Generate AttrDef C API declarations",
+    [](const RecordKeeper &records, raw_ostream &os) {
+      return CAPIAttrDefGenerator(records, os).emitDecls(capiDialect);
+    });
+
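+// TODO: register "gen-attrdef-capi-defs" once CAPIDefGenerator overrides
+// emitDefs; until then the inherited DefGenerator::emitDefs would be used.
+// static mlir::GenRegistration
+//     genAttrDefs("gen-attrdef-capi-defs",
+//                 "Generate AttrDef C API definitions",
+//                 [](const RecordKeeper &records, raw_ostream &os) {
+//                   CAPIAttrDefGenerator generator(records, os);
+//                   return generator.emitDefs(capiDialect);
+//                 });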
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
+/// `-gen-typedef-capi-decls`: emits C API declarations for the TypeDefs of
+/// the dialect selected by -attr-or-type-capi-dialect.
+static mlir::GenRegistration genTypeDecls(
+    "gen-typedef-capi-decls", "Generate TypeDef C API declarations",
+    [](const RecordKeeper &records, raw_ostream &os) {
+      return CAPITypeDefGenerator(records, os).emitDecls(capiDialect);
+    });
+
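+// TODO: register "gen-typedef-capi-defs" once CAPIDefGenerator overrides
+// emitDefs; until then the inherited DefGenerator::emitDefs would be used.
+// static mlir::GenRegistration
+//     genTypeDefs("gen-typedef-capi-defs",
+//                 "Generate TypeDef C API definitions",
+//                 [](const RecordKeeper &records, raw_ostream &os) {
+//                   CAPITypeDefGenerator generator(records, os);
+//                   return generator.emitDefs(capiDialect);
+//                 });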
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..c828dc6f67746 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -32,9 +32,9 @@ using llvm::RecordKeeper;
 
 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
 /// specified and can only find one dialect's defs, use that.
-static void collectAllDefs(StringRef selectedDialect,
-                           ArrayRef<const Record *> records,
-                           SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
+void mlir::tblgen::collectAllDefs(StringRef selectedDialect,
+                                  ArrayRef<const Record *> records,
+                                  SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
   // Nothing to do if no defs were found.
   if (records.empty())
     return;
@@ -804,42 +804,6 @@ void DefGen::emitStorageClass() {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// This struct is the base generator used when processing tablegen interfaces.
-class DefGenerator {
-public:
-  bool emitDecls(StringRef selectedDialect);
-  bool emitDefs(StringRef selectedDialect);
-
-protected:
-  DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
-               StringRef defType, StringRef valueType, bool isAttrGenerator)
-      : defRecords(defs), os(os), defType(defType), valueType(valueType),
-        isAttrGenerator(isAttrGenerator) {
-    // Sort by occurrence in file.
-    llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
-      return lhs->getID() < rhs->getID();
-    });
-  }
-
-  /// Emit the list of def type names.
-  void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
-  /// Emit the code to dispatch between different defs during parsing/printing.
-  void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
-
-  /// The set of def records to emit.
-  std::vector<const Record *> defRecords;
-  /// The attribute or type class to emit.
-  /// The stream to emit to.
-  raw_ostream &os;
-  /// The prefix of the tablegen def name, e.g. Attr or Type.
-  StringRef defType;
-  /// The C++ base value type of the def, e.g. Attribute or Type.
-  StringRef valueType;
-  /// Flag indicating if this generator is for Attributes. False if the
-  /// generator is for types.
-  bool isAttrGenerator;
-};
-
 /// A specialized generator for AttrDefs.
 struct AttrDefGenerator : public DefGenerator {
   AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
index d4711532a79bb..ca20fdba5ba96 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
@@ -20,6 +20,50 @@ class AttrOrTypeDef;
 void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser,
                               MethodBody &printer);
 
+/// Collects into `resultDefs` the AttrOrTypeDefs among `records` that belong
+/// to `selectedDialect`. If no dialect is specified and the records contain
+/// defs of only one dialect, that dialect is used.
+void collectAllDefs(StringRef selectedDialect,
+                    ArrayRef<const llvm::Record *> records,
+                    SmallVectorImpl<AttrOrTypeDef> &resultDefs);
+
+/// Base generator for a set of AttrDef or TypeDef records. It emits
+/// declarations (emitDecls) and definitions (emitDefs) for the defs of a
+/// selected dialect. Both are virtual so other backends, such as the C API
+/// generator, can reuse the record collection and override the emission.
+class DefGenerator {
+public:
+  virtual ~DefGenerator() = default;
+  /// Emits declarations for the defs of `selectedDialect`. As with other
+  /// TableGen backends, returns true on failure.
+  virtual bool emitDecls(StringRef selectedDialect);
+  /// Emits definitions for the defs of `selectedDialect`. Returns true on
+  /// failure.
+  virtual bool emitDefs(StringRef selectedDialect);
+
+protected:
+  DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
+               StringRef defType, StringRef valueType, bool isAttrGenerator)
+      : defRecords(defs), os(os), defType(defType), valueType(valueType),
+        isAttrGenerator(isAttrGenerator) {
+    // Sort by occurrence in file.
+    llvm::sort(defRecords,
+               [](const llvm::Record *lhs, const llvm::Record *rhs) {
+                 return lhs->getID() < rhs->getID();
+               });
+  }
+
+  /// Emit the list of def type names.
+  void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
+  /// Emit the code to dispatch between different defs during parsing/printing.
+  void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
+
+  /// The set of def records to emit, sorted by occurrence in the input.
+  std::vector<const llvm::Record *> defRecords;
+  /// The stream to emit to.
+  raw_ostream &os;
+  /// The prefix of the tablegen def name, e.g. Attr or Type.
+  StringRef defType;
+  /// The C++ base value type of the def, e.g. Attribute or Type.
+  StringRef valueType;
+  /// Flag indicating if this generator is for Attributes. False if the
+  /// generator is for types.
+  bool isAttrGenerator;
+};
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index d7087cba3c874..4256613ce4848 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
 add_tablegen(mlir-tblgen MLIR
   DESTINATION "${MLIR_TOOLS_INSTALL_DIR}"
   EXPORT MLIR
+  AttrOrTypeCAPIGen.cpp
   AttrOrTypeDefGen.cpp
   AttrOrTypeFormatGen.cpp
   BytecodeDialectGen.cpp



More information about the Mlir-commits mailing list