[Mlir-commits] [mlir] c3f381c - [mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 19 21:02:29 PDT 2026
Author: Maksim Levental
Date: 2026-03-19T21:02:23-07:00
New Revision: c3f381ccfe4b48f204df07e2c8cd36542a60d553
URL: https://github.com/llvm/llvm-project/commit/c3f381ccfe4b48f204df07e2c8cd36542a60d553
DIFF: https://github.com/llvm/llvm-project/commit/c3f381ccfe4b48f204df07e2c8cd36542a60d553.diff
LOG: [mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)
When multiple dialects share td `#include`s (e.g. `affine` includes
`arith`), each dialect's `*_enum_gen.py` file registers attribute
builders under the same keys, causing "already registered" errors on the
second import. The first commit adds a test for this case, which currently
fails on main:
```
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
```
This PR implements a two-pronged fix:
1. Add an `allow_existing` parameter (default `False`) to
`register_attribute_builder` (and the underlying C++
`registerAttributeBuilder`). When it is `True`, registration is skipped
if the key already exists (first-wins semantics); assertion-enabled
(non-`NDEBUG`) builds additionally emit a `RuntimeWarning`. The generated
`EnumInfo`-based builders pass `allow_existing=True`. These builders have
no dialect prefix (e.g. `AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`),
and they may be emitted by every dialect whose td file includes the
defining file;
2. Filter `EnumAttr` builders by `-bind-dialect` in
`EnumPythonBindingGen.cpp` and register them under dialect-qualified
keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up
the same qualified keys for `EnumAttr`-typed op attributes (detected via
`isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from
`AddMLIRPython.cmake`.
This approach requires no changes to the existing builder registrations
in `ir.py` (no "builtin." prefix) and no manual builder additions to
individual dialect Python files (unlike the previous attempt
https://github.com/llvm/llvm-project/pull/117918).
Note: this PR was "clauded", not "coded".
Added:
Modified:
mlir/cmake/modules/AddMLIRPython.cmake
mlir/include/mlir/Bindings/Python/Globals.h
mlir/include/mlir/Bindings/Python/IRCore.h
mlir/lib/Bindings/Python/Globals.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/ir.py
mlir/test/mlir-tblgen/enums-python-bindings.td
mlir/test/python/dialects/affine.py
mlir/test/python/dialects/index_dialect.py
mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 6ac5003538e45..07f97e3261a33 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -608,7 +608,7 @@ function(declare_mlir_dialect_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
@@ -680,7 +680,7 @@ function(declare_mlir_dialect_extension_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 8a7f30fd218dc..b2aa169744c97 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -58,11 +58,14 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
bool loadDialectModule(std::string_view dialectNamespace);
/// Adds a user-friendly Attribute builder.
- /// Raises an exception if the mapping already exists and replace == false.
+ /// Raises an exception if the mapping already exists and replace == false
+ /// and allow_existing == false.
+ /// Silently skips registration if allow_existing == true and the mapping
+ /// already exists (first registration wins).
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
- nanobind::callable pyFunc,
- bool replace = false);
+ nanobind::callable pyFunc, bool replace = false,
+ bool allow_existing = false);
/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 557e32e9a612d..f24b3c6ac6f80 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1356,7 +1356,8 @@ struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
static nanobind::callable
dunderGetItemNamed(const std::string &attributeKind);
static void dunderSetItemNamed(const std::string &attributeKind,
- nanobind::callable func, bool replace);
+ nanobind::callable func, bool replace,
+ bool allow_existing);
static void bind(nanobind::module_ &m);
};
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 82195acb9f4fb..1e48eac27dd83 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -97,14 +97,27 @@ bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) {
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
- nb::callable pyFunc, bool replace) {
+ nb::callable pyFunc, bool replace,
+ bool allowExisting) {
nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
- if (found && !replace) {
- throw std::runtime_error(
+ if (found) {
+ std::string msg =
nanobind::detail::join("Attribute builder for '", attributeKind,
"' is already registered with func: ",
- nb::cast<std::string>(nb::str(found))));
+ nb::cast<std::string>(nb::str(found)));
+ if (allowExisting) {
+#ifndef NDEBUG
+ if (PyErr_WarnEx(PyExc_RuntimeWarning, msg.c_str(), 1) < 0) {
+ // If the user has set warnings to errors (e.g., via -Werror),
+ // PyErr_WarnEx returns -1 and sets a Python exception.
+ throw nb::python_error();
+ }
+#endif
+ return;
+ }
+ if (!replace)
+ throw std::runtime_error(msg);
}
found = std::move(pyFunc);
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3d07e364b5c98..89e1e21cd1240 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -154,9 +154,10 @@ PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
}
void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
- nb::callable func, bool replace) {
+ nb::callable func, bool replace,
+ bool allow_existing) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
+ replace, allow_existing);
}
void PyAttrBuilderMap::bind(nb::module_ &m) {
@@ -171,6 +172,7 @@ void PyAttrBuilderMap::bind(nb::module_ &m) {
"attribute kind.")
.def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
"attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
+ "allow_existing"_a = false,
"Register an attribute builder for building MLIR "
"attributes from Python values.");
}
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 210465daad0d8..3795f5cb2e036 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -95,9 +95,9 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, Non
# Convenience decorator for registering user-friendly Attribute builders.
-def register_attribute_builder(kind, replace=False):
+def register_attribute_builder(kind, replace=False, allow_existing=False):
def decorator_builder(func):
- AttrBuilder.insert(kind, func, replace=replace)
+ AttrBuilder.insert(kind, func, replace=replace, allow_existing=allow_existing)
return func
return decorator_builder
diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index cd23b6a2effb9..74b9f51b0c2d6 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -35,7 +35,7 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
// CHECK: return "negone"
// CHECK: raise ValueError("Unknown MyEnum enum entry.")
-// CHECK: @register_attribute_builder("MyEnum")
+// CHECK: @register_attribute_builder("MyEnum", allow_existing=True)
// CHECK: def _myenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
@@ -58,7 +58,7 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
// CHECK: return "two"
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")
-// CHECK: @register_attribute_builder("MyEnum64")
+// CHECK: @register_attribute_builder("MyEnum64", allow_existing=True)
// CHECK: def _myenum64(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
@@ -102,14 +102,14 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: return "any"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
-// CHECK: @register_attribute_builder("TestBitEnum")
+// CHECK: @register_attribute_builder("TestBitEnum", allow_existing=True)
// CHECK: def _testbitenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
-// CHECK: @register_attribute_builder("TestBitEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
-// CHECK: @register_attribute_builder("TestMyEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
// CHECK: def _testmyenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index c797234fd16d6..1b1655692f15c 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -335,6 +335,11 @@ def simple_affine_if_else(cond_operands):
return
+@constructAndPrintInModule
+def test_double_AtomicRMWKindAttr_registration():
+ from mlir.dialects import _affine_enum_gen
+
+
# CHECK-LABEL: TEST: testAffineIfOpInsertionPoint
@constructAndPrintInModule
def testAffineIfOpInsertionPoint():
diff --git a/mlir/test/python/dialects/index_dialect.py b/mlir/test/python/dialects/index_dialect.py
index 9db883469792c..8da6a262cc441 100644
--- a/mlir/test/python/dialects/index_dialect.py
+++ b/mlir/test/python/dialects/index_dialect.py
@@ -94,7 +94,7 @@ def testCeilDivUOp(ctx):
def testCmpOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
- pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
+ pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx)
r = index.CmpOp(pred, lhs=a, rhs=b)
# CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index acc9b61d7121c..6cef09d9958c7 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -17,6 +17,7 @@
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/GenInfo.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
@@ -26,6 +27,10 @@ using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
+// Declared in OpPythonBindingGen.cpp; the two generators share the same
+// -bind-dialect option to allow filtering enum registrations by dialect.
+extern std::string dialectNameStorage;
+
/// File header and includes.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
@@ -94,7 +99,11 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
return false;
int64_t bitwidth = enumInfo.getBitwidth();
- os << formatv("@register_attribute_builder(\"{0}\")\n",
+ // These builders may be emitted by multiple dialect enum_gen files when
+ // dialects share enum definitions via .td includes. Use allow_existing=True
+ // so that the first loaded dialect registers the builder and subsequent
+ // loads silently skip (first-registration wins).
+ os << formatv("@register_attribute_builder(\"{0}\", allow_existing=True)\n",
enumAttrInfo->getAttrDefName());
os << formatv("def _{0}(x, context):\n",
enumAttrInfo->getAttrDefName().lower());
@@ -108,10 +117,12 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
+static bool emitDialectEnumAttributeBuilder(StringRef dialect,
+ StringRef attrDefName,
StringRef formatString,
raw_ostream &os) {
- os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
+ os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
+ attrDefName);
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
os << formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
@@ -132,6 +143,12 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
+ StringRef dialect = attr.getDialect().getName();
+ // When -bind-dialect is specified, only emit builders for EnumAttr records
+ // belonging to that dialect. This prevents duplicate registrations when
+ // multiple dialects include the same .td files.
+ if (!dialectNameStorage.empty() && dialect != dialectNameStorage)
+ continue;
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
<< " needs mnemonic for python enum bindings generation";
@@ -139,14 +156,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
}
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
- attr.getName(),
+ dialect, attr.getName(),
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder(
- attr.getName(),
+ dialect, attr.getName(),
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else {
llvm::errs()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index e8acf4ce40fc8..84dce9bdf0c6d 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -341,10 +341,13 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
-static llvm::cl::opt<std::string>
+std::string dialectNameStorage;
+
+llvm::cl::opt<std::string, /*ExternalStorage=*/true>
clDialectName("bind-dialect",
llvm::cl::desc("The dialect to run the generator for"),
- llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
+ llvm::cl::location(dialectNameStorage),
+ llvm::cl::cat(clOpPythonBindingCat));
static llvm::cl::opt<std::string> clDialectExtensionName(
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
@@ -887,11 +890,27 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
continue;
}
+ // For EnumAttr-style attributes (those defined as EnumAttr<Dialect, ...>
+ // in tablegen), use a dialect-qualified key ("dialect.AttrName") so the
+ // lookup matches the registration emitted by EnumPythonBindingGen with
+ // -bind-dialect. For all other attributes (plain attrs like I32Attr,
+ // custom AttrDef, etc.), keep the unqualified name to match their
+ // registrations in ir.py or dialect-specific Python files.
+ Attribute baseAttr = attribute->attr.getBaseAttr();
+ Dialect attrDialect = baseAttr.isSubClassOf("EnumAttr")
+ ? baseAttr.getDialect()
+ : Dialect(nullptr);
+ std::string attrBuilderKey = attrDialect
+ ? formatv("{0}.{1}", attrDialect.getName(),
+ attribute->attr.getAttrDefName())
+ .str()
+ : attribute->attr.getAttrDefName().str();
+
builderLines.push_back(formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
- argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+ argNames[i], attribute->name, attrBuilderKey));
}
}
@@ -1307,18 +1326,18 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
- if (clDialectName.empty())
+ if (dialectNameStorage.empty())
llvm::PrintFatalError("dialect name not provided");
os << fileHeader;
if (!clDialectExtensionName.empty())
- os << formatv(dialectExtensionTemplate, clDialectName.getValue());
+ os << formatv(dialectExtensionTemplate, dialectNameStorage);
else
- os << formatv(dialectClassTemplate, clDialectName.getValue());
+ os << formatv(dialectClassTemplate, dialectNameStorage);
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
- if (op.getDialectName() == clDialectName.getValue())
+ if (op.getDialectName() == dialectNameStorage)
emitOpBindings(op, os);
}
return false;
More information about the Mlir-commits
mailing list