[Mlir-commits] [mlir] [mlir][python] namespace generated enums in python (PR #77830)

Maksim Levental llvmlistbot at llvm.org
Thu Jan 11 12:56:54 PST 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/77830

>From 26a8d9c422239870d9429abcfdb524bf8569c014 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 11 Jan 2024 14:56:18 -0600
Subject: [PATCH] namespace generated enums in python

---
 .../mlir-tblgen/EnumPythonBindingGen.cpp      | 54 +++++++++----------
 1 file changed, 26 insertions(+), 28 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..9a1961482d7d09 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -19,6 +19,8 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
 
+#include <llvm/IR/DebugInfo.h>
+
 using namespace mlir;
 using namespace mlir::tblgen;
 
@@ -105,7 +107,8 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+  os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+                      enumAttr.getDialect().getName(),
                       enumAttr.getAttrDefName());
   os << llvm::formatv("def _{0}(x, context):\n",
                       enumAttr.getAttrDefName().lower());
@@ -120,11 +123,25 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
-                                            StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
                                             raw_ostream &os) {
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
-  os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+  StringRef mnemonic = attr.getMnemonic().value();
+  std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+  StringRef dialect = attr.getDialect().getName();
+  std::string formatString;
+  if (assemblyFormat == "`<` $value `>`") {
+    formatString =
+        llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+  } else if (assemblyFormat == "$value") {
+    formatString =
+        llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+  } else {
+    llvm::errs()
+        << "unsupported assembly format for python enum bindings generation";
+    return true;
+  }
+  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attr.getName());
+  os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
   os << llvm::formatv("    return "
                       "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
                       formatString);
@@ -142,29 +159,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
     emitEnumClass(enumAttr, os);
     emitAttributeBuilder(enumAttr, os);
   }
-  for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
-    AttrOrTypeDef attr(&*it);
-    if (!attr.getMnemonic()) {
-      llvm::errs() << "enum case " << attr
-                   << " needs mnemonic for python enum bindings generation";
-      return true;
-    }
-    StringRef mnemonic = attr.getMnemonic().value();
-    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
-    StringRef dialect = attr.getDialect().getName();
-    if (assemblyFormat == "`<` $value `>`") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
-    } else if (assemblyFormat == "$value") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
-    } else {
-      llvm::errs()
-          << "unsupported assembly format for python enum bindings generation";
-      return true;
-    }
+  for (const auto &it :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+    const AttrOrTypeDef attr(&*it);
+    return emitDialectEnumAttributeBuilder(attr, os);
   }
 
   return false;



More information about the Mlir-commits mailing list