[Mlir-commits] [mlir] b544ad5 - [MLIR] [Python] Added a way to extend MLIR->Python type mappings (#189368)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 31 03:00:46 PDT 2026


Author: Sergei Lebedev
Date: 2026-03-31T11:00:40+01:00
New Revision: b544ad57039588d0fe24a1f512202cc5c0bd3a67

URL: https://github.com/llvm/llvm-project/commit/b544ad57039588d0fe24a1f512202cc5c0bd3a67
DIFF: https://github.com/llvm/llvm-project/commit/b544ad57039588d0fe24a1f512202cc5c0bd3a67.diff

LOG: [MLIR] [Python] Added a way to extend MLIR->Python type mappings (#189368)

The idea is to use TableGen records for both custom type constraints and
attributes:

* `PythonTypeName` is for type constraints, while
* `PythonAttrType` is for attributes.

The key types differ between these two records. `PythonTypeName` is
keyed by C++ type because multiple type constraints map to the same C++
type (e.g. `I32` and `I64` both map to `::mlir::IntegerType`), so a
single entry covers all of them. `PythonAttrType` is keyed by TableGen
def name because different attributes can share the same C++ storage
type but need distinct Python types (e.g. `I32ArrayAttr` and
`StrArrayAttr` are both `::mlir::ArrayAttr`).

We could in theory reimplement `getPythonAttrName` using the same
approach, but I decided to leave it for future PRs.

Added: 
    mlir/include/mlir/Bindings/Python/PythonBindings.td

Modified: 
    mlir/docs/Bindings/Python.md
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index df186e43219c2..c43181859f968 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1180,6 +1180,29 @@ filled with `import`s from the generated files to enable `import
 mlir.dialects.<dialect-namespace>` in Python.
 
 
+#### Customizing Type Annotations
+
+The generated `__init__` methods include type annotations for operand and
+attribute arguments. Built-in mappings cover standard MLIR types and attributes,
+but dialects can extend them by defining records of the classes declared in
+[`PythonBindings.td`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PythonBindings.td)
+in the same `.td` file passed to `mlir-tblgen -gen-python-op-bindings`:
+
+```tablegen
+include "mlir/Bindings/Python/PythonBindings.td"
+
+// Operand/result annotations: maps a C++ type from the ODS type constraint's
+// cppClassName to a Python type annotation, e.g.
+// `ir.Value` -> `ir.Value[my_dialect.MyTensorType]`.
+def : PythonTypeName<"::my_dialect::MyTensorType",
+                     "my_dialect.MyTensorType">;
+
+// Attribute annotations: maps a TableGen attribute def name to the Python
+// type accepted by its AttrBuilder, e.g.
+// `Union[Any, ir.Attribute]` -> `Union[my_dialect.MyValue, ir.Attribute]`.
+def : PythonAttrType<"MyCustomAttr", "my_dialect.MyValue">;
+```
+
 ### Attributes and Types
 
 Dialect attributes and types are provided in Python as subclasses of the

diff  --git a/mlir/include/mlir/Bindings/Python/PythonBindings.td b/mlir/include/mlir/Bindings/Python/PythonBindings.td
new file mode 100644
index 0000000000000..b652b1381b6d6
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/PythonBindings.td
@@ -0,0 +1,31 @@
+//===-- PythonBindings.td - Python binding type mappings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C++ type and attribute to Python type mappings for -gen-python-op-bindings.
+// Dialects can include this file and add their own mappings.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TD
+#define PYTHON_BINDINGS_TD
+
+/// Maps a C++ type to a Python type annotation for operands and results.
+/// Example: PythonTypeName<"::mlir::IntegerType", "_ods_ir.IntegerType">
+class PythonTypeName<string cppType, string pythonType> {
+  string cppName = cppType;
+  string pyName = pythonType;
+}
+
+/// Maps a TableGen attribute def name to the Python type accepted by its
+/// AttrBuilder. Example: PythonAttrType<"I32Attr", "int">
+class PythonAttrType<string attrDefName, string pythonType> {
+  string defName = attrDefName;
+  string pyType = pythonType;
+}
+
+#endif // PYTHON_BINDINGS_TD

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 5e29f3f61e5c8..f566290915dde 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -3,6 +3,7 @@
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Bindings/Python/PythonBindings.td"
 
 // CHECK: @_ods_cext.register_dialect
 // CHECK: class _Dialect(_ods_ir.Dialect):
@@ -15,6 +16,15 @@ def Test_Dialect : Dialect {
 class TestOp<string mnemonic, list<Trait> traits = []> :
     Op<Test_Dialect, mnemonic, traits>;
 
+def TestCustomAttr
+    : DialectAttr<Test_Dialect,
+                  CPred<"::llvm::isa<::test::CustomAttr>($_self)">,
+                  "custom attribute"> {
+  let storageType = "::test::CustomAttr";
+}
+
+def : PythonAttrType<"TestCustomAttr", "test.CustomType">;
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK-LABEL: class AttrSizedOperandsOp(_ods_ir.OpView):
 // CHECK: OPERATION_NAME = "test.attr_sized_operands"
@@ -217,6 +227,12 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 // CHECK:   return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class CustomAttrOp(_ods_ir.OpView):
+def CustomAttrOp : TestOp<"custom_attr"> {
+  // CHECK: def __init__(self, custom: _Union[test.CustomType, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
+  let arguments = (ins TestCustomAttr:$custom);
+}
+
 // CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
 // CHECK: OPERATION_NAME = "test.default_valued_attrs"
 def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 14500edca5bcf..81c598ebbef0a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -30,6 +31,78 @@ using llvm::formatv;
 using llvm::Record;
 using llvm::RecordKeeper;
 
+/// Built-in C++ type to Python type mappings.
+static constexpr std::pair<StringRef, StringRef> builtinTypeMappings[] = {
+    {"::mlir::MemRefType", "_ods_ir.MemRefType"},
+    {"::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType"},
+    {"::mlir::RankedTensorType", "_ods_ir.RankedTensorType"},
+    {"::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType"},
+    {"::mlir::VectorType", "_ods_ir.VectorType"},
+    {"::mlir::IntegerType", "_ods_ir.IntegerType"},
+    {"::mlir::FloatType", "_ods_ir.FloatType"},
+    {"::mlir::IndexType", "_ods_ir.IndexType"},
+    {"::mlir::ComplexType", "_ods_ir.ComplexType"},
+    {"::mlir::TupleType", "_ods_ir.TupleType"},
+    {"::mlir::NoneType", "_ods_ir.NoneType"},
+};
+
+/// Built-in TableGen attribute def name to Python type mappings.
+static constexpr std::pair<StringRef, StringRef> builtinAttrTypeMappings[] = {
+    {"BoolAttr", "bool"},
+    {"I1Attr", "bool"},
+    {"I8Attr", "int"},
+    {"I16Attr", "int"},
+    {"I32Attr", "int"},
+    {"I64Attr", "int"},
+    {"SI1Attr", "int"},
+    {"SI8Attr", "int"},
+    {"SI16Attr", "int"},
+    {"SI32Attr", "int"},
+    {"SI64Attr", "int"},
+    {"UI1Attr", "int"},
+    {"UI8Attr", "int"},
+    {"UI16Attr", "int"},
+    {"UI32Attr", "int"},
+    {"UI64Attr", "int"},
+    {"IndexAttr", "int"},
+    {"F32Attr", "float"},
+    {"F64Attr", "float"},
+    {"StrAttr", "str"},
+    {"SymbolNameAttr", "str"},
+    {"FlatSymbolRefAttr", "str"},
+    {"SymbolRefAttr", "str"},
+    {"TypeAttr", "_ods_ir.Type"},
+    {"AffineMapAttr", "_ods_ir.AffineMap"},
+    {"IntegerSetAttr", "_ods_ir.IntegerSet"},
+    {"DictionaryAttr", "dict"},
+    {"ArrayAttr", "_Sequence[_ods_ir.Attribute]"},
+    {"I32ArrayAttr", "_Sequence[int]"},
+    {"I64ArrayAttr", "_Sequence[int]"},
+    {"I64SmallVectorArrayAttr", "_Sequence[int]"},
+    {"F32ArrayAttr", "_Sequence[float]"},
+    {"F64ArrayAttr", "_Sequence[float]"},
+    {"BoolArrayAttr", "_Sequence[bool]"},
+    {"DenseBoolArrayAttr", "_Sequence[bool]"},
+    {"StrArrayAttr", "_Sequence[str]"},
+    {"FlatSymbolRefArrayAttr", "_Sequence[str]"},
+    {"DenseI8ArrayAttr", "_Sequence[int]"},
+    {"DenseI16ArrayAttr", "_Sequence[int]"},
+    {"DenseI32ArrayAttr", "_Sequence[int]"},
+    {"DenseI64ArrayAttr", "_Sequence[int]"},
+    {"DenseF32ArrayAttr", "_Sequence[float]"},
+    {"DenseF64ArrayAttr", "_Sequence[float]"},
+    {"I32ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+    {"I64ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+    {"IndexElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+    {"F64ElementsAttr", "_Union[_Sequence[float], _Buffer]"},
+};
+
+/// Maps from C++ type names to Python type annotations.
+static llvm::StringMap<std::string> pythonTypeMap;
+
+/// Maps from TableGen attribute def names to Python types.
+static llvm::StringMap<std::string> pythonAttrTypeMap;
+
 /// File header and includes.
 ///   {0} is the dialect namespace.
 constexpr const char *fileHeader = R"Py(
@@ -397,20 +470,13 @@ static std::string attrSizedTraitForKind(const char *kind) {
                  StringRef(kind).drop_front());
 }
 
-static StringRef getPythonType(StringRef cppType) {
-  return llvm::StringSwitch<StringRef>(cppType)
-      .Case("::mlir::MemRefType", "_ods_ir.MemRefType")
-      .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
-      .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
-      .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
-      .Case("::mlir::VectorType", "_ods_ir.VectorType")
-      .Case("::mlir::IntegerType", "_ods_ir.IntegerType")
-      .Case("::mlir::FloatType", "_ods_ir.FloatType")
-      .Case("::mlir::IndexType", "_ods_ir.IndexType")
-      .Case("::mlir::ComplexType", "_ods_ir.ComplexType")
-      .Case("::mlir::TupleType", "_ods_ir.TupleType")
-      .Case("::mlir::NoneType", "_ods_ir.NoneType")
-      .Default(StringRef());
+/// Returns the Python type annotation for a given type constraint.
+/// Returns empty StringRef if no mapping is known.
+static StringRef getPythonType(const tblgen::TypeConstraint &constraint) {
+  auto it = pythonTypeMap.find(constraint.getCppType());
+  if (it != pythonTypeMap.end())
+    return it->second;
+  return StringRef();
 }
 
 /// Emits accessors to "elements" of an Op definition. Currently, the supported
@@ -447,7 +513,7 @@ static void emitElementAccessors(
         continue;
       std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                            : "_ods_ir.OpResult";
-      if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+      if (StringRef pythonType = getPythonType(element.constraint);
           !pythonType.empty())
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       if (element.isVariableLength()) {
@@ -457,8 +523,7 @@ static void emitElementAccessors(
         } else {
           type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                    : "_ods_ir.OpResultList";
-          if (StringRef pythonType =
-                  getPythonType(element.constraint.getCppType());
+          if (StringRef pythonType = getPythonType(element.constraint);
               !pythonType.empty())
             type = llvm::formatv("{0}[{1}]", type, pythonType);
           os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
@@ -500,8 +565,7 @@ static void emitElementAccessors(
           type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                    : "_ods_ir.OpResult";
         }
-        if (StringRef pythonType =
-                getPythonType(element.constraint.getCppType());
+        if (StringRef pythonType = getPythonType(element.constraint);
             !pythonType.empty()) {
           type = llvm::formatv("{0}[{1}]", type, pythonType);
         }
@@ -536,8 +600,7 @@ static void emitElementAccessors(
       if (!element.isVariableLength() || element.isOptional()) {
         type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                  : "_ods_ir.OpResult";
-        if (StringRef pythonType =
-                getPythonType(element.constraint.getCppType());
+        if (StringRef pythonType = getPythonType(element.constraint);
             !pythonType.empty()) {
           type = llvm::formatv("{0}[{1}]", type, pythonType);
         }
@@ -549,8 +612,7 @@ static void emitElementAccessors(
               formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
         }
       } else {
-        if (StringRef pythonType =
-                getPythonType(element.constraint.getCppType());
+        if (StringRef pythonType = getPythonType(element.constraint);
             !pythonType.empty()) {
           type = llvm::formatv("{0}[{1}]", type, pythonType);
         }
@@ -654,36 +716,13 @@ static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
   return "Attribute";
 }
 
-/// Returns the Python raw value type accepted by the AttrBuilder for the given
+/// Returns the Python value type accepted by the AttrBuilder for the given
 /// attribute. Returns empty StringRef if no mapping is known.
-static StringRef getPythonAttrRawType(mlir::tblgen::Attribute attr) {
-  return llvm::StringSwitch<StringRef>(attr.getAttrDefName())
-      .Cases({"BoolAttr", "I1Attr"}, "bool")
-      .Cases({"I8Attr", "I16Attr", "I32Attr", "I64Attr"}, "int")
-      .Cases({"SI1Attr", "SI8Attr", "SI16Attr", "SI32Attr", "SI64Attr"}, "int")
-      .Cases({"UI1Attr", "UI8Attr", "UI16Attr", "UI32Attr", "UI64Attr"}, "int")
-      .Case("IndexAttr", "int")
-      .Cases({"F32Attr", "F64Attr"}, "float")
-      .Cases({"StrAttr", "SymbolNameAttr"}, "str")
-      .Cases({"FlatSymbolRefAttr", "SymbolRefAttr"}, "str")
-      .Case("TypeAttr", "_ods_ir.Type")
-      .Case("AffineMapAttr", "_ods_ir.AffineMap")
-      .Case("IntegerSetAttr", "_ods_ir.IntegerSet")
-      .Case("DictionaryAttr", "dict")
-      .Case("ArrayAttr", "_Sequence[_ods_ir.Attribute]")
-      .Cases({"I32ArrayAttr", "I64ArrayAttr", "I64SmallVectorArrayAttr"},
-             "_Sequence[int]")
-      .Cases({"F32ArrayAttr", "F64ArrayAttr"}, "_Sequence[float]")
-      .Cases({"BoolArrayAttr", "DenseBoolArrayAttr"}, "_Sequence[bool]")
-      .Cases({"StrArrayAttr", "FlatSymbolRefArrayAttr"}, "_Sequence[str]")
-      .Cases({"DenseI8ArrayAttr", "DenseI16ArrayAttr", "DenseI32ArrayAttr",
-              "DenseI64ArrayAttr"},
-             "_Sequence[int]")
-      .Cases({"DenseF32ArrayAttr", "DenseF64ArrayAttr"}, "_Sequence[float]")
-      .Cases({"I32ElementsAttr", "I64ElementsAttr", "IndexElementsAttr"},
-             "_Union[_Sequence[int], _Buffer]")
-      .Case("F64ElementsAttr", "_Union[_Sequence[float], _Buffer]")
-      .Default(StringRef());
+static StringRef getPythonAttrType(mlir::tblgen::Attribute attr) {
+  auto it = pythonAttrTypeMap.find(attr.getAttrDefName());
+  if (it != pythonAttrTypeMap.end())
+    return it->second;
+  return StringRef();
 }
 
 /// Emits accessors to Op attributes.
@@ -1180,7 +1219,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
         argTypes[idx] = "bool";
       } else {
         std::string attrType = "_ods_ir." + getPythonAttrName(nattr->attr);
-        StringRef rawType = getPythonAttrRawType(nattr->attr);
+        StringRef rawType = getPythonAttrType(nattr->attr);
         argTypes[idx] =
             llvm::formatv("_Union[{0}, {1}]",
                           rawType.empty() ? "_Any" : rawType, attrType)
@@ -1189,7 +1228,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
     } else if (auto *ntype =
                    llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
       std::string type = "_ods_ir.Value";
-      if (StringRef pythonType = getPythonType(ntype->constraint.getCppType());
+      if (StringRef pythonType = getPythonType(ntype->constraint);
           !pythonType.empty()) {
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       }
@@ -1379,8 +1418,7 @@ static void emitValueBuilder(const Operator &op,
       results = ".results";
     } else if (op.getNumResults() == 1) {
       type = "_ods_ir.OpResult";
-      if (StringRef pythonType =
-              getPythonType(op.getResult(0).constraint.getCppType());
+      if (StringRef pythonType = getPythonType(op.getResult(0).constraint);
           !pythonType.empty())
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       results = ".result";
@@ -1444,6 +1482,25 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
   emitValueBuilder(op, functionArgs, os);
 }
 
+/// Rebuilds `map` from scratch: first seeds it with `builtins`, then adds one
+/// entry per definition of the TableGen class `recordClass` found in `records`,
+/// using the record's `keyField` string as the key and its `valueField` string
+/// as the value.
+///
+/// A record whose key is already present (from a built-in or from an earlier
+/// record) is accepted only if it maps to the identical value; otherwise a
+/// fatal error is reported at the record's location.
+static void populateTypeMap(llvm::StringMap<std::string> &map,
+                            ArrayRef<std::pair<StringRef, StringRef>> builtins,
+                            const RecordKeeper &records, StringRef recordClass,
+                            StringRef keyField, StringRef valueField) {
+  map.clear();
+  for (auto [key, value] : builtins)
+    map[key] = value.str();
+  for (const Record *rec :
+       records.getAllDerivedDefinitionsIfDefined(recordClass)) {
+    StringRef key = rec->getValueAsString(keyField);
+    StringRef value = rec->getValueAsString(valueField);
+    // Hand the map its own copy and keep `value` as a view of the record's
+    // string, so the conflict check and the diagnostic below never read a
+    // moved-from object.
+    auto [it, inserted] = map.try_emplace(key, value.str());
+    if (!inserted && StringRef(it->second) != value)
+      llvm::PrintFatalError(rec->getLoc(),
+                            "conflicting " + recordClass + " for '" + key +
+                                "': '" + it->second + "' vs '" + value + "'");
+  }
+}
+
 /// Emits bindings for the dialect specified in the command line, including file
 /// headers and utilities. Returns `false` on success to comply with Tablegen
 /// registration requirements.
@@ -1451,6 +1508,11 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
   if (dialectNameStorage.empty())
     llvm::PrintFatalError("dialect name not provided");
 
+  populateTypeMap(pythonTypeMap, builtinTypeMappings, records, "PythonTypeName",
+                  "cppName", "pyName");
+  populateTypeMap(pythonAttrTypeMap, builtinAttrTypeMappings, records,
+                  "PythonAttrType", "defName", "pyType");
+
   os << fileHeader;
   if (!clDialectExtensionName.empty())
     os << formatv(dialectExtensionTemplate, dialectNameStorage);


        


More information about the Mlir-commits mailing list