[Mlir-commits] [mlir] [MLIR][Python] Extend bindings for external projects without duplication (PR #173241)

Sylvain Noiry llvmlistbot at llvm.org
Tue Dec 23 00:58:03 PST 2025


https://github.com/ElectrikSpace updated https://github.com/llvm/llvm-project/pull/173241

>From fff6f57b470756ce585bd79e156daec53ad7306e Mon Sep 17 00:00:00 2001
From: Sylvain Noiry <snoiry at kalrayinc.com>
Date: Mon, 22 Dec 2025 10:25:52 +0100
Subject: [PATCH] [MLIR][Python] Extend bindings without duplication

Add an MLIR_PYTHON_PACKAGE option to declare_mlir_dialect_python_bindings
and declare_mlir_dialect_extension_python_bindings to create bindings
as an extension to a main (usually upstream) mlir package. This changes
the import paths in Python files generated by mlir-tblgen from relative
imports to absolute imports rooted at that package (default: "mlir").

MLIRPythonCAPI is now installed (DISABLE_INSTALL is dropped), so it can be
linked dynamically in add_mlir_python_modules, which prevents mangling
issues caused by code duplication.
---
 mlir/cmake/modules/AddMLIRPython.cmake        | 25 ++++++++++++++-----
 .../examples/standalone/python/CMakeLists.txt |  3 ++-
 .../test/mlir-tblgen/enums-python-bindings.td |  4 +--
 .../mlir-tblgen/EnumPythonBindingGen.cpp      | 10 +++++---
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 23 ++++++++++++-----
 5 files changed, 47 insertions(+), 18 deletions(-)

diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index ca90151e76268..2d86a9a9d0bc4 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -389,6 +389,8 @@ endfunction()
 #     This file is where the *EnumAttrs are defined, not where the *Enums are defined.
 #     **WARNING**: This arg will shortly be removed when the just-below TODO is satisfied. Use at your
 #     risk.
+#   MLIR_PYTHON_PACKAGE: Optional name of the main MLIR Python package that the
+#     generated files import from (default: "mlir"), e.g. to build extensions.
 #
 # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we
 #       use its path to determine where to place the generated python file. If
@@ -397,7 +399,7 @@ endfunction()
 function(declare_mlir_dialect_python_bindings)
   cmake_parse_arguments(ARG
     "GEN_ENUM_BINDINGS"
-    "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME"
+    "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;MLIR_PYTHON_PACKAGE"
     "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
     ${ARGN})
   # Sources.
@@ -417,8 +419,12 @@ function(declare_mlir_dialect_python_bindings)
     file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
     set(dialect_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_ops_gen.py")
     set(LLVM_TARGET_DEFINITIONS ${td_file})
+    if(NOT ARG_MLIR_PYTHON_PACKAGE)
+      set(ARG_MLIR_PYTHON_PACKAGE "mlir")
+    endif()
     mlir_tablegen("${dialect_filename}"
       -gen-python-op-bindings -bind-dialect=${ARG_DIALECT_NAME}
+      -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE}
       DEPENDS ${ARG_DEPENDS}
     )
     add_public_tablegen_target(${tblgen_target})
@@ -430,7 +436,8 @@ function(declare_mlir_dialect_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings
+                    -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE})
       list(APPEND _sources ${enum_filename})
     endif()
 
@@ -464,10 +471,12 @@ endfunction()
 #     This file is where the *Attrs are defined, not where the *Enums are defined.
 #     **WARNING**: This arg will shortly be removed when the TODO for
 #     declare_mlir_dialect_python_bindings is satisfied. Use at your risk.
+#   MLIR_PYTHON_PACKAGE: Optional name of the main MLIR Python package that the
+#     generated files import from (default: "mlir"), e.g. to build extensions.
 function(declare_mlir_dialect_extension_python_bindings)
   cmake_parse_arguments(ARG
     "GEN_ENUM_BINDINGS"
-    "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
+    "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME;MLIR_PYTHON_PACKAGE"
     "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
     ${ARGN})
   # Source files.
@@ -487,9 +496,13 @@ function(declare_mlir_dialect_extension_python_bindings)
     file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
     set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
     set(LLVM_TARGET_DEFINITIONS ${td_file})
+    if(NOT ARG_MLIR_PYTHON_PACKAGE)
+      set(ARG_MLIR_PYTHON_PACKAGE "mlir")
+    endif()
     mlir_tablegen("${output_filename}" -gen-python-op-bindings
                   -bind-dialect=${ARG_DIALECT_NAME}
-                  -dialect-extension=${ARG_EXTENSION_NAME})
+                  -dialect-extension=${ARG_EXTENSION_NAME}
+                  -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE})
     add_public_tablegen_target(${tblgen_target})
     if(ARG_DEPENDS)
       add_dependencies(${tblgen_target} ${ARG_DEPENDS})
@@ -502,7 +515,8 @@ function(declare_mlir_dialect_extension_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings
+                    -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE})
       list(APPEND _sources ${enum_filename})
     endif()
 
@@ -601,7 +615,6 @@ function(add_mlir_python_common_capi_library name)
   # Generate the aggregate .so that everything depends on.
   add_mlir_aggregate(${name}
     SHARED
-    DISABLE_INSTALL
     EMBED_LIBS ${_embed_libs}
   )
 
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index edaedf18cc843..91373da0e9377 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -18,7 +18,8 @@ declare_mlir_dialect_python_bindings(
   SOURCES
     dialects/standalone_nanobind.py
     _mlir_libs/_standaloneDialectsNanobind/py.typed
-  DIALECT_NAME standalone)
+  DIALECT_NAME standalone
+  MLIR_PYTHON_PACKAGE "${MLIR_PYTHON_PACKAGE_PREFIX}")
 
 declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
   MODULE_NAME _standaloneDialectsNanobind
diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index cd23b6a2effb9..5dd002ca21bd3 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -10,8 +10,8 @@ def Test_Dialect : Dialect {
 // CHECK: Autogenerated by mlir-tblgen; don't manually edit.
 
 // CHECK: from enum import IntEnum, auto, IntFlag
-// CHECK: from ._ods_common import _cext as _ods_cext
-// CHECK: from ..ir import register_attribute_builder
+// CHECK: from mlir.dialects._ods_common import _cext as _ods_cext
+// CHECK: from mlir.ir import register_attribute_builder
 // CHECK: _ods_ir = _ods_cext.ir
 
 def One : I32EnumAttrCase<"CaseOne", 1, "one">;
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index acc9b61d7121c..31211a2094f06 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
 
@@ -27,16 +28,19 @@ using llvm::Record;
 using llvm::RecordKeeper;
 
 /// File header and includes.
+///   {0} is the Python package prefix.
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from enum import IntEnum, auto, IntFlag
-from ._ods_common import _cext as _ods_cext
-from ..ir import register_attribute_builder
+from {0}.dialects._ods_common import _cext as _ods_cext
+from {0}.ir import register_attribute_builder
 _ods_ir = _ods_cext.ir
 
 )Py";
 
+extern llvm::cl::opt<std::string> clPythonPackagePrefix;
+
 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
 static std::string makePythonEnumCaseName(StringRef name) {
   if (isPythonReserved(name.str()))
@@ -122,7 +126,7 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
 /// Emits Python bindings for all enums in the record keeper. Returns
 /// `false` on success, `true` on failure.
 static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
-  os << fileHeader;
+  os << formatv(fileHeader, clPythonPackagePrefix.getValue());
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
     EnumInfo enumInfo(*it);
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c33f4efac3ac..fd3a82427352d 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,12 +30,12 @@ using llvm::Record;
 using llvm::RecordKeeper;
 
 /// File header and includes.
-///   {0} is the dialect namespace.
+///   {0} is the Python package prefix.
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
-from ._ods_common import _cext as _ods_cext
-from ._ods_common import (
+from {0}.dialects._ods_common import _cext as _ods_cext
+from {0}.dialects._ods_common import (
     equally_sized_accessor as _ods_equally_sized_accessor,
     get_default_loc_context as _ods_get_default_loc_context,
     get_op_results_or_values as _get_op_results_or_values,
@@ -58,7 +58,8 @@ class _Dialect(_ods_ir.Dialect):
 )Py";
 
+/// Dialect extension template: {0} dialect namespace, {1} package prefix.
 constexpr const char *dialectExtensionTemplate = R"Py(
-from ._{0}_ops_gen import _Dialect
+from {1}.dialects._{0}_ops_gen import _Dialect
 )Py";
 
 /// Template for operation class:
@@ -293,6 +294,15 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
   return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
 )Py";
 
+static llvm::cl::OptionCategory
+    clPythonBindingCat("Options for -gen-python-op-bindings and "
+                       "-gen-python-enum-bindings");
+
+llvm::cl::opt<std::string> clPythonPackagePrefix(
+    "python-package-prefix",
+    llvm::cl::desc("The prefix of the MLIR Python package"),
+    llvm::cl::init("mlir"), llvm::cl::cat(clPythonBindingCat));
+
 static llvm::cl::OptionCategory
     clOpPythonBindingCat("Options for -gen-python-op-bindings");
 
@@ -1222,9 +1232,10 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
   if (clDialectName.empty())
     llvm::PrintFatalError("dialect name not provided");
 
-  os << fileHeader;
+  os << formatv(fileHeader, clPythonPackagePrefix.getValue());
   if (!clDialectExtensionName.empty())
-    os << formatv(dialectExtensionTemplate, clDialectName.getValue());
+    os << formatv(dialectExtensionTemplate, clDialectName.getValue(),
+                  clPythonPackagePrefix.getValue());
   else
     os << formatv(dialectClassTemplate, clDialectName.getValue());
 



More information about the Mlir-commits mailing list