[Mlir-commits] [mlir] c3f381c - [mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 19 21:02:29 PDT 2026


Author: Maksim Levental
Date: 2026-03-19T21:02:23-07:00
New Revision: c3f381ccfe4b48f204df07e2c8cd36542a60d553

URL: https://github.com/llvm/llvm-project/commit/c3f381ccfe4b48f204df07e2c8cd36542a60d553
DIFF: https://github.com/llvm/llvm-project/commit/c3f381ccfe4b48f204df07e2c8cd36542a60d553.diff

LOG: [mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)

When multiple dialects share td `#includes` (e.g. `affine` includes
`arith`), each dialect's `*_enum_gen.py` file registers attribute
builders under the same keys, causing "already registered" errors on the
second import. The first commit adds a test for such a case, which
currently fails on main:

```
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
```

This PR implements a two-pronged fix:

1. Add an `allow_existing` parameter (default `False`) to
`register_attribute_builder` (and the underlying C++
`registerAttributeBuilder`). When it is `True` and the key is already
registered, the new registration is skipped (first-wins semantics); in
builds without `NDEBUG`, a `RuntimeWarning` is emitted for the skipped
duplicate. This handles `EnumInfo`-based builders, which have no dialect
prefix (e.g. `AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`) and may be
emitted by every dialect whose td file includes the defining file;
2. Filter `EnumAttr` builders by `-bind-dialect` in
`EnumPythonBindingGen.cpp` and register them under dialect-qualified
keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up
the same qualified keys for EnumAttr-typed op attributes (detected via
`isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from
`AddMLIRPython.cmake`.

This approach requires no changes to the existing registrations in
`ir.py` (no "builtin." prefix) and no manual builder additions to
individual dialect Python files (unlike the previous attempt,
https://github.com/llvm/llvm-project/pull/117918).

Note: this PR was "clauded" (written with Claude), not "coded".

Added: 
    

Modified: 
    mlir/cmake/modules/AddMLIRPython.cmake
    mlir/include/mlir/Bindings/Python/Globals.h
    mlir/include/mlir/Bindings/Python/IRCore.h
    mlir/lib/Bindings/Python/Globals.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/ir.py
    mlir/test/mlir-tblgen/enums-python-bindings.td
    mlir/test/python/dialects/affine.py
    mlir/test/python/dialects/index_dialect.py
    mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 6ac5003538e45..07f97e3261a33 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -608,7 +608,7 @@ function(declare_mlir_dialect_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
       list(APPEND _sources ${enum_filename})
     endif()
 
@@ -680,7 +680,7 @@ function(declare_mlir_dialect_extension_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
       list(APPEND _sources ${enum_filename})
     endif()
 

diff  --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 8a7f30fd218dc..b2aa169744c97 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -58,11 +58,14 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
   bool loadDialectModule(std::string_view dialectNamespace);
 
   /// Adds a user-friendly Attribute builder.
-  /// Raises an exception if the mapping already exists and replace == false.
+  /// Raises an exception if the mapping already exists and replace == false
+  /// and allow_existing == false.
+  /// Silently skips registration if allow_existing == true and the mapping
+  /// already exists (first registration wins).
   /// This is intended to be called by implementation code.
   void registerAttributeBuilder(const std::string &attributeKind,
-                                nanobind::callable pyFunc,
-                                bool replace = false);
+                                nanobind::callable pyFunc, bool replace = false,
+                                bool allow_existing = false);
 
   /// Adds a user-friendly type caster. Raises an exception if the mapping
   /// already exists and replace == false. This is intended to be called by

diff  --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 557e32e9a612d..f24b3c6ac6f80 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1356,7 +1356,8 @@ struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
   static nanobind::callable
   dunderGetItemNamed(const std::string &attributeKind);
   static void dunderSetItemNamed(const std::string &attributeKind,
-                                 nanobind::callable func, bool replace);
+                                 nanobind::callable func, bool replace,
+                                 bool allow_existing);
 
   static void bind(nanobind::module_ &m);
 };

diff  --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 82195acb9f4fb..1e48eac27dd83 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -97,14 +97,27 @@ bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) {
 }
 
 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
-                                         nb::callable pyFunc, bool replace) {
+                                         nb::callable pyFunc, bool replace,
+                                         bool allowExisting) {
   nb::ft_lock_guard lock(mutex);
   nb::object &found = attributeBuilderMap[attributeKind];
-  if (found && !replace) {
-    throw std::runtime_error(
+  if (found) {
+    std::string msg =
         nanobind::detail::join("Attribute builder for '", attributeKind,
                                "' is already registered with func: ",
-                               nb::cast<std::string>(nb::str(found))));
+                               nb::cast<std::string>(nb::str(found)));
+    if (allowExisting) {
+#ifndef NDEBUG
+      if (PyErr_WarnEx(PyExc_RuntimeWarning, msg.c_str(), 1) < 0) {
+        // If the user has set warnings to errors (e.g., via -Werror),
+        // PyErr_WarnEx returns -1 and sets a Python exception.
+        throw nb::python_error();
+      }
+#endif
+      return;
+    }
+    if (!replace)
+      throw std::runtime_error(msg);
   }
   found = std::move(pyFunc);
 }

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3d07e364b5c98..89e1e21cd1240 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -154,9 +154,10 @@ PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
 }
 
 void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
-                                          nb::callable func, bool replace) {
+                                          nb::callable func, bool replace,
+                                          bool allow_existing) {
   PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
-                                            replace);
+                                            replace, allow_existing);
 }
 
 void PyAttrBuilderMap::bind(nb::module_ &m) {
@@ -171,6 +172,7 @@ void PyAttrBuilderMap::bind(nb::module_ &m) {
                   "attribute kind.")
       .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
                   "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
+                  "allow_existing"_a = false,
                   "Register an attribute builder for building MLIR "
                   "attributes from Python values.");
 }

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 210465daad0d8..3795f5cb2e036 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -95,9 +95,9 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, Non
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
-def register_attribute_builder(kind, replace=False):
+def register_attribute_builder(kind, replace=False, allow_existing=False):
     def decorator_builder(func):
-        AttrBuilder.insert(kind, func, replace=replace)
+        AttrBuilder.insert(kind, func, replace=replace, allow_existing=allow_existing)
         return func
 
     return decorator_builder

diff  --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index cd23b6a2effb9..74b9f51b0c2d6 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -35,7 +35,7 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
 // CHECK:             return "negone"
 // CHECK:         raise ValueError("Unknown MyEnum enum entry.")
 
-// CHECK: @register_attribute_builder("MyEnum")
+// CHECK: @register_attribute_builder("MyEnum", allow_existing=True)
 // CHECK: def _myenum(x, context):
 // CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
 
@@ -58,7 +58,7 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
 // CHECK:             return "two"
 // CHECK:         raise ValueError("Unknown MyEnum64 enum entry.")
 
-// CHECK: @register_attribute_builder("MyEnum64")
+// CHECK: @register_attribute_builder("MyEnum64", allow_existing=True)
 // CHECK: def _myenum64(x, context):
 // CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
 
@@ -102,14 +102,14 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
 // CHECK:             return "any"
 // CHECK:         raise ValueError("Unknown TestBitEnum enum entry.")
 
-// CHECK: @register_attribute_builder("TestBitEnum")
+// CHECK: @register_attribute_builder("TestBitEnum", allow_existing=True)
 // CHECK: def _testbitenum(x, context):
 // CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
 
-// CHECK: @register_attribute_builder("TestBitEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
 // CHECK: def _testbitenum_attr(x, context):
 // CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
 
-// CHECK: @register_attribute_builder("TestMyEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
 // CHECK: def _testmyenum_attr(x, context):
 // CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)

diff  --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index c797234fd16d6..1b1655692f15c 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -335,6 +335,11 @@ def simple_affine_if_else(cond_operands):
         return
 
 
+@constructAndPrintInModule
+def test_double_AtomicRMWKindAttr_registration():
+    from mlir.dialects import _affine_enum_gen
+
+
 # CHECK-LABEL: TEST: testAffineIfOpInsertionPoint
 @constructAndPrintInModule
 def testAffineIfOpInsertionPoint():

diff  --git a/mlir/test/python/dialects/index_dialect.py b/mlir/test/python/dialects/index_dialect.py
index 9db883469792c..8da6a262cc441 100644
--- a/mlir/test/python/dialects/index_dialect.py
+++ b/mlir/test/python/dialects/index_dialect.py
@@ -94,7 +94,7 @@ def testCeilDivUOp(ctx):
 def testCmpOp(ctx):
     a = index.ConstantOp(value=42)
     b = index.ConstantOp(value=23)
-    pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
+    pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx)
     r = index.CmpOp(pred, lhs=a, rhs=b)
     # CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})
 

diff  --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index acc9b61d7121c..6cef09d9958c7 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
 
@@ -26,6 +27,10 @@ using llvm::formatv;
 using llvm::Record;
 using llvm::RecordKeeper;
 
+// Declared in OpPythonBindingGen.cpp; the two generators share the same
+// -bind-dialect option to allow filtering enum registrations by dialect.
+extern std::string dialectNameStorage;
+
 /// File header and includes.
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
@@ -94,7 +99,11 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
     return false;
 
   int64_t bitwidth = enumInfo.getBitwidth();
-  os << formatv("@register_attribute_builder(\"{0}\")\n",
+  // These builders may be emitted by multiple dialect enum_gen files when
+  // dialects share enum definitions via .td includes. Use allow_existing=True
+  // so that the first loaded dialect registers the builder and subsequent
+  // loads silently skip (first-registration wins).
+  os << formatv("@register_attribute_builder(\"{0}\", allow_existing=True)\n",
                 enumAttrInfo->getAttrDefName());
   os << formatv("def _{0}(x, context):\n",
                 enumAttrInfo->getAttrDefName().lower());
@@ -108,10 +117,12 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
+static bool emitDialectEnumAttributeBuilder(StringRef dialect,
+                                            StringRef attrDefName,
                                             StringRef formatString,
                                             raw_ostream &os) {
-  os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
+  os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
+                attrDefName);
   os << formatv("def _{0}(x, context):\n", attrDefName.lower());
   os << formatv("    return "
                 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
@@ -132,6 +143,12 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
     AttrOrTypeDef attr(&*it);
+    StringRef dialect = attr.getDialect().getName();
+    // When -bind-dialect is specified, only emit builders for EnumAttr records
+    // belonging to that dialect. This prevents duplicate registrations when
+    // multiple dialects include the same .td files.
+    if (!dialectNameStorage.empty() && dialect != dialectNameStorage)
+      continue;
     if (!attr.getMnemonic()) {
       llvm::errs() << "enum case " << attr
                    << " needs mnemonic for python enum bindings generation";
@@ -139,14 +156,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
     }
     StringRef mnemonic = attr.getMnemonic().value();
     std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
-    StringRef dialect = attr.getDialect().getName();
     if (assemblyFormat == "`<` $value `>`") {
       emitDialectEnumAttributeBuilder(
-          attr.getName(),
+          dialect, attr.getName(),
           formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
     } else if (assemblyFormat == "$value") {
       emitDialectEnumAttributeBuilder(
-          attr.getName(),
+          dialect, attr.getName(),
           formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
     } else {
       llvm::errs()

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index e8acf4ce40fc8..84dce9bdf0c6d 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -341,10 +341,13 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
 static llvm::cl::OptionCategory
     clOpPythonBindingCat("Options for -gen-python-op-bindings");
 
-static llvm::cl::opt<std::string>
+std::string dialectNameStorage;
+
+llvm::cl::opt<std::string, /*ExternalStorage=*/true>
     clDialectName("bind-dialect",
                   llvm::cl::desc("The dialect to run the generator for"),
-                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
+                  llvm::cl::location(dialectNameStorage),
+                  llvm::cl::cat(clOpPythonBindingCat));
 
 static llvm::cl::opt<std::string> clDialectExtensionName(
     "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
@@ -887,11 +890,27 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
       continue;
     }
 
+    // For EnumAttr-style attributes (those defined as EnumAttr<Dialect, ...>
+    // in tablegen), use a dialect-qualified key ("dialect.AttrName") so the
+    // lookup matches the registration emitted by EnumPythonBindingGen with
+    // -bind-dialect. For all other attributes (plain attrs like I32Attr,
+    // custom AttrDef, etc.), keep the unqualified name to match their
+    // registrations in ir.py or dialect-specific Python files.
+    Attribute baseAttr = attribute->attr.getBaseAttr();
+    Dialect attrDialect = baseAttr.isSubClassOf("EnumAttr")
+                              ? baseAttr.getDialect()
+                              : Dialect(nullptr);
+    std::string attrBuilderKey = attrDialect
+                                     ? formatv("{0}.{1}", attrDialect.getName(),
+                                               attribute->attr.getAttrDefName())
+                                           .str()
+                                     : attribute->attr.getAttrDefName().str();
+
     builderLines.push_back(formatv(
         attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
             ? initOptionalAttributeWithBuilderTemplate
             : initAttributeWithBuilderTemplate,
-        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+        argNames[i], attribute->name, attrBuilderKey));
   }
 }
 
@@ -1307,18 +1326,18 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
 /// headers and utilities. Returns `false` on success to comply with Tablegen
 /// registration requirements.
 static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
-  if (clDialectName.empty())
+  if (dialectNameStorage.empty())
     llvm::PrintFatalError("dialect name not provided");
 
   os << fileHeader;
   if (!clDialectExtensionName.empty())
-    os << formatv(dialectExtensionTemplate, clDialectName.getValue());
+    os << formatv(dialectExtensionTemplate, dialectNameStorage);
   else
-    os << formatv(dialectClassTemplate, clDialectName.getValue());
+    os << formatv(dialectClassTemplate, dialectNameStorage);
 
   for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);
-    if (op.getDialectName() == clDialectName.getValue())
+    if (op.getDialectName() == dialectNameStorage)
       emitOpBindings(op, os);
   }
   return false;


        


More information about the Mlir-commits mailing list