[Mlir-commits] [mlir] [mlir][python] namespace generated enums in python (PR #77830)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 11 14:42:37 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

A recent PR broke enum bindings generation because of a collision between attrs with the same name across dialects: https://github.com/llvm/llvm-project/pull/77211#discussion_r1449317601. So we need to namespace these now. In the current form of the PR, this is a breaking change (anyone who is supplying their own attribute builders won't have the `<dialect>_attr` prefix). I can rewrite it to check for both `<dialect>_attr` and just `attr` in the `AttrBuilder` queries, but I don't know what people think is best.

Will add/update tests shortly.

---
Full diff: https://github.com/llvm/llvm-project/pull/77830.diff


2 Files Affected:

- (modified) mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (+25-28) 
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+10-7) 


``````````diff
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index f4ced0803772ed..d5f36ff3bc0fd2 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -105,7 +105,8 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
+  os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+                      enumAttr.getDialect().getName(),
                       enumAttr.getAttrDefName());
   os << llvm::formatv("def _{0}(x, context):\n",
                       enumAttr.getAttrDefName().lower());
@@ -120,11 +121,26 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
-                                            StringRef formatString,
+static bool emitDialectEnumAttributeBuilder(const AttrOrTypeDef &attr,
                                             raw_ostream &os) {
-  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
-  os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+  StringRef mnemonic = attr.getMnemonic().value();
+  std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+  StringRef dialect = attr.getDialect().getName();
+  std::string formatString;
+  if (assemblyFormat == "`<` $value `>`") {
+    formatString =
+        llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str();
+  } else if (assemblyFormat == "$value") {
+    formatString =
+        llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str();
+  } else {
+    llvm::errs()
+        << "unsupported assembly format for python enum bindings generation";
+    return true;
+  }
+  os << llvm::formatv("@register_attribute_builder(\"{0}_{1}\")\n",
+                      attr.getDialect().getName(), attr.getName());
+  os << llvm::formatv("def _{0}(x, context):\n", attr.getName().lower());
   os << llvm::formatv("    return "
                       "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
                       formatString);
@@ -142,29 +158,10 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
     emitEnumClass(enumAttr, os);
     emitAttributeBuilder(enumAttr, os);
   }
-  for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
-    AttrOrTypeDef attr(&*it);
-    if (!attr.getMnemonic()) {
-      llvm::errs() << "enum case " << attr
-                   << " needs mnemonic for python enum bindings generation";
-      return true;
-    }
-    StringRef mnemonic = attr.getMnemonic().value();
-    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
-    StringRef dialect = attr.getDialect().getName();
-    if (assemblyFormat == "`<` $value `>`") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
-    } else if (assemblyFormat == "$value") {
-      emitDialectEnumAttributeBuilder(
-          attr.getName(),
-          llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
-    } else {
-      llvm::errs()
-          << "unsupported assembly format for python enum bindings generation";
-      return true;
-    }
+  for (const auto &it :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+    const AttrOrTypeDef attr(&*it);
+    return emitDialectEnumAttributeBuilder(attr, os);
   }
 
   return false;
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..de343df1c434fa 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -529,27 +529,29 @@ constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 /// Template for attribute builder from raw input in the operation builder.
 ///   {0} is the builder argument name;
 ///   {1} is the attribute builder from raw;
-///   {2} is the attribute builder from raw.
+///   {2} is the attribute builder from raw;
+///   {3} is the attribute's dialect.
 /// Use the value the user passed in if either it is already an Attribute or
 /// there is no method registered to make it an Attribute.
 constexpr const char *initAttributeWithBuilderTemplate =
     R"Py(attributes["{1}"] = ({0} if (
     issubclass(type({0}), _ods_ir.Attribute) or
-    not _ods_ir.AttrBuilder.contains('{2}')) else
-      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+    not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+      _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
 
 /// Template for attribute builder from raw input for optional attribute in the
 /// operation builder.
 ///   {0} is the builder argument name;
 ///   {1} is the attribute builder from raw;
-///   {2} is the attribute builder from raw.
+///   {2} is the attribute builder from raw;
+///   {3} is the attribute's dialect.
 /// Use the value the user passed in if either it is already an Attribute or
 /// there is no method registered to make it an Attribute.
 constexpr const char *initOptionalAttributeWithBuilderTemplate =
     R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
         issubclass(type({0}), _ods_ir.Attribute) or
-        not _ods_ir.AttrBuilder.contains('{2}')) else
-          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+        not _ods_ir.AttrBuilder.contains('{2}_{3}')) else
+          _ods_ir.AttrBuilder.get('{2}_{3}')({0}, context=_ods_context)))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -681,7 +683,8 @@ populateBuilderLinesAttr(const Operator &op,
         attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
             ? initOptionalAttributeWithBuilderTemplate
             : initAttributeWithBuilderTemplate,
-        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+        argNames[i], attribute->name, attribute->attr.getAttrDefName(),
+        attribute->attr.getDialect().getName()));
   }
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/77830


More information about the Mlir-commits mailing list