[Mlir-commits] [mlir] [MLIR] [Python] Added a way to extend MLIR->Python type mappings (PR #189368)

Sergei Lebedev llvmlistbot at llvm.org
Mon Mar 30 05:51:29 PDT 2026


https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/189368

>From 6adf4eb9c83d1dd6a092abea649ed6c845bbd97d Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Mon, 30 Mar 2026 13:01:54 +0100
Subject: [PATCH] [MLIR] [Python] Added a way to extend MLIR->Python type
 mappings

The idea is to let dialects declare additional mappings as TableGen records,
with one record class for type constraints and one for attributes:

* `PythonTypeName` is for type constraints, while
* `PythonAttrType` is for attributes.

The key types differ between these two records. `PythonTypeName` is keyed by C++
type because multiple type constraints map to the same C++ type (e.g. `I32` and
`I64` both map to `::mlir::IntegerType`), so a single entry covers all of
them. `PythonAttrType` is keyed by TableGen def name because different attributes
can share the same C++ storage type but need distinct Python
types (e.g. `I32ArrayAttr` and `StrArrayAttr` are both `::mlir::ArrayAttr`).

We could in theory reimplement `getPythonAttrName` using the same approach,
but I decided to leave it for future PRs.
---
 .../mlir/Bindings/Python/PythonBindings.td    |  31 +++
 mlir/test/mlir-tblgen/op-python-bindings.td   |  16 ++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 178 ++++++++++++------
 3 files changed, 169 insertions(+), 56 deletions(-)
 create mode 100644 mlir/include/mlir/Bindings/Python/PythonBindings.td

diff --git a/mlir/include/mlir/Bindings/Python/PythonBindings.td b/mlir/include/mlir/Bindings/Python/PythonBindings.td
new file mode 100644
index 0000000000000..b652b1381b6d6
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/PythonBindings.td
@@ -0,0 +1,31 @@
+//===-- PythonBindings.td - Python binding type mappings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C++ type and attribute to Python type mappings for -gen-python-op-bindings.
+// Dialects can include this file and add their own mappings.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TD
+#define PYTHON_BINDINGS_TD
+
+/// Maps a C++ type to a Python type annotation for operands and results.
+/// Example: PythonTypeName<"::mlir::IntegerType", "_ods_ir.IntegerType">
+class PythonTypeName<string cppType, string pythonType> {
+  string cppName = cppType;   // Matched on TypeConstraint::getCppType().
+  string pyName = pythonType; // Used as T in _ods_ir.Value[T] and similar.
+}
+
+/// Maps a TableGen attribute def name to the Python type accepted by its
+/// AttrBuilder. Example: PythonAttrType<"I32Attr", "int">
+class PythonAttrType<string attrDefName, string pythonType> {
+  string defName = attrDefName; // Matched on Attribute::getAttrDefName().
+  string pyType = pythonType;   // Used as T in _Union[T, <attribute class>].
+}
+
+#endif // PYTHON_BINDINGS_TD
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 5e29f3f61e5c8..712c5a06c3239 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -3,6 +3,7 @@
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Bindings/Python/PythonBindings.td"
 
 // CHECK: @_ods_cext.register_dialect
 // CHECK: class _Dialect(_ods_ir.Dialect):
@@ -15,6 +16,22 @@ def Test_Dialect : Dialect {
 class TestOp<string mnemonic, list<Trait> traits = []> :
     Op<Test_Dialect, mnemonic, traits>;
 
+def TestCustomAttr
+    : DialectAttr<Test_Dialect,
+                  CPred<"::llvm::isa<::test::CustomAttr>($_self)">,
+                  "custom attribute"> {
+  let storageType = "::test::CustomAttr";
+}
+
+def : PythonAttrType<"TestCustomAttr", "object">;
+
+def TestCustomType
+    : DialectType<Test_Dialect,
+                  CPred<"::llvm::isa<::test::CustomType>($_self)">,
+                  "custom type", "::test::CustomType">;
+
+def : PythonTypeName<"::test::CustomType", "_ods_ir.Type">;
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK-LABEL: class AttrSizedOperandsOp(_ods_ir.OpView):
 // CHECK: OPERATION_NAME = "test.attr_sized_operands"
@@ -217,6 +234,20 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 // CHECK:   return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class CustomAttrOp(_ods_ir.OpView):
+def CustomAttrOp : TestOp<"custom_attr"> {
+  // CHECK: def __init__(self, custom: _Union[object, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
+  let arguments = (ins TestCustomAttr:$custom);
+}
+
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class CustomTypeOp(_ods_ir.OpView):
+def CustomTypeOp : TestOp<"custom_type"> {
+  // CHECK: def __init__(self, custom: {{.*}}_ods_ir.Value[_ods_ir.Type]
+  let arguments = (ins TestCustomType:$custom);
+}
+
+// CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
 // CHECK: OPERATION_NAME = "test.default_valued_attrs"
 def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 14c76295daaee..277e12c4c6291 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -30,6 +31,78 @@ using llvm::formatv;
 using llvm::Record;
 using llvm::RecordKeeper;
 
+/// Built-in C++ type -> Python type mappings; PythonTypeName defs add more.
+static constexpr std::pair<StringRef, StringRef> builtinTypeMappings[] = {
+    {"::mlir::MemRefType", "_ods_ir.MemRefType"},
+    {"::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType"},
+    {"::mlir::RankedTensorType", "_ods_ir.RankedTensorType"},
+    {"::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType"},
+    {"::mlir::VectorType", "_ods_ir.VectorType"},
+    {"::mlir::IntegerType", "_ods_ir.IntegerType"},
+    {"::mlir::FloatType", "_ods_ir.FloatType"},
+    {"::mlir::IndexType", "_ods_ir.IndexType"},
+    {"::mlir::ComplexType", "_ods_ir.ComplexType"},
+    {"::mlir::TupleType", "_ods_ir.TupleType"},
+    {"::mlir::NoneType", "_ods_ir.NoneType"},
+};
+
+/// Built-in attr def name -> Python value type; PythonAttrType defs add more.
+static constexpr std::pair<StringRef, StringRef> builtinAttrTypeMappings[] = {
+    {"BoolAttr", "bool"},
+    {"I1Attr", "bool"},
+    {"I8Attr", "int"},
+    {"I16Attr", "int"},
+    {"I32Attr", "int"},
+    {"I64Attr", "int"},
+    {"SI1Attr", "int"},
+    {"SI8Attr", "int"},
+    {"SI16Attr", "int"},
+    {"SI32Attr", "int"},
+    {"SI64Attr", "int"},
+    {"UI1Attr", "int"},
+    {"UI8Attr", "int"},
+    {"UI16Attr", "int"},
+    {"UI32Attr", "int"},
+    {"UI64Attr", "int"},
+    {"IndexAttr", "int"},
+    {"F32Attr", "float"},
+    {"F64Attr", "float"},
+    {"StrAttr", "str"},
+    {"SymbolNameAttr", "str"},
+    {"FlatSymbolRefAttr", "str"},
+    {"SymbolRefAttr", "str"},
+    {"TypeAttr", "_ods_ir.Type"},
+    {"AffineMapAttr", "_ods_ir.AffineMap"},
+    {"IntegerSetAttr", "_ods_ir.IntegerSet"},
+    {"DictionaryAttr", "dict"},
+    {"ArrayAttr", "_Sequence[_ods_ir.Attribute]"},
+    {"I32ArrayAttr", "_Sequence[int]"},
+    {"I64ArrayAttr", "_Sequence[int]"},
+    {"I64SmallVectorArrayAttr", "_Sequence[int]"},
+    {"F32ArrayAttr", "_Sequence[float]"},
+    {"F64ArrayAttr", "_Sequence[float]"},
+    {"BoolArrayAttr", "_Sequence[bool]"},
+    {"DenseBoolArrayAttr", "_Sequence[bool]"},
+    {"StrArrayAttr", "_Sequence[str]"},
+    {"FlatSymbolRefArrayAttr", "_Sequence[str]"},
+    {"DenseI8ArrayAttr", "_Sequence[int]"},
+    {"DenseI16ArrayAttr", "_Sequence[int]"},
+    {"DenseI32ArrayAttr", "_Sequence[int]"},
+    {"DenseI64ArrayAttr", "_Sequence[int]"},
+    {"DenseF32ArrayAttr", "_Sequence[float]"},
+    {"DenseF64ArrayAttr", "_Sequence[float]"},
+    {"I32ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+    {"I64ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+    {"IndexElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+    {"F64ElementsAttr", "_Union[_Sequence[float], _Buffer]"},
+};
+
+/// C++ type name -> Python annotation; rebuilt on each emitAllOps call.
+static llvm::StringMap<std::string> pythonTypeMap;
+
+/// Attribute def name -> Python value type; rebuilt on each emitAllOps call.
+static llvm::StringMap<std::string> pythonAttrTypeMap;
+
 /// File header and includes.
 ///   {0} is the dialect namespace.
 constexpr const char *fileHeader = R"Py(
@@ -397,20 +470,13 @@ static std::string attrSizedTraitForKind(const char *kind) {
                  StringRef(kind).drop_front());
 }
 
-static StringRef getPythonType(StringRef cppType) {
-  return llvm::StringSwitch<StringRef>(cppType)
-      .Case("::mlir::MemRefType", "_ods_ir.MemRefType")
-      .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
-      .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
-      .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
-      .Case("::mlir::VectorType", "_ods_ir.VectorType")
-      .Case("::mlir::IntegerType", "_ods_ir.IntegerType")
-      .Case("::mlir::FloatType", "_ods_ir.FloatType")
-      .Case("::mlir::IndexType", "_ods_ir.IndexType")
-      .Case("::mlir::ComplexType", "_ods_ir.ComplexType")
-      .Case("::mlir::TupleType", "_ods_ir.TupleType")
-      .Case("::mlir::NoneType", "_ods_ir.NoneType")
-      .Default(StringRef());
+/// Looks up the Python annotation registered for the C++ type of `constraint`.
+/// An empty StringRef means no mapping is known.
+static StringRef getPythonType(const tblgen::TypeConstraint &constraint) {
+  auto found = pythonTypeMap.find(constraint.getCppType());
+  if (found == pythonTypeMap.end())
+    return StringRef();
+  return found->second;
 }
 
 /// Emits accessors to "elements" of an Op definition. Currently, the supported
@@ -447,7 +513,7 @@ static void emitElementAccessors(
         continue;
       std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                            : "_ods_ir.OpResult";
-      if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+      if (StringRef pythonType = getPythonType(element.constraint);
           !pythonType.empty())
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       if (element.isVariableLength()) {
@@ -457,8 +523,7 @@ static void emitElementAccessors(
         } else {
           type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                    : "_ods_ir.OpResultList";
-          if (StringRef pythonType =
-                  getPythonType(element.constraint.getCppType());
+          if (StringRef pythonType = getPythonType(element.constraint);
               !pythonType.empty())
             type = llvm::formatv("{0}[{1}]", type, pythonType);
           os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
@@ -500,8 +565,7 @@ static void emitElementAccessors(
           type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                    : "_ods_ir.OpResult";
         }
-        if (StringRef pythonType =
-                getPythonType(element.constraint.getCppType());
+        if (StringRef pythonType = getPythonType(element.constraint);
             !pythonType.empty()) {
           type = llvm::formatv("{0}[{1}]", type, pythonType);
         }
@@ -536,8 +600,7 @@ static void emitElementAccessors(
       if (!element.isVariableLength() || element.isOptional()) {
         type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                  : "_ods_ir.OpResult";
-        if (StringRef pythonType =
-                getPythonType(element.constraint.getCppType());
+        if (StringRef pythonType = getPythonType(element.constraint);
             !pythonType.empty()) {
           type = llvm::formatv("{0}[{1}]", type, pythonType);
         }
@@ -549,8 +612,7 @@ static void emitElementAccessors(
               formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
         }
       } else {
-        if (StringRef pythonType =
-                getPythonType(element.constraint.getCppType());
+        if (StringRef pythonType = getPythonType(element.constraint);
             !pythonType.empty()) {
           type = llvm::formatv("{0}[{1}]", type, pythonType);
         }
@@ -654,36 +716,13 @@ static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
   return "Attribute";
 }
 
-/// Returns the Python raw value type accepted by the AttrBuilder for the given
+/// Returns the Python value type accepted by the AttrBuilder for the given
 /// attribute. Returns empty StringRef if no mapping is known.
-static StringRef getPythonAttrRawType(mlir::tblgen::Attribute attr) {
-  return llvm::StringSwitch<StringRef>(attr.getAttrDefName())
-      .Cases({"BoolAttr", "I1Attr"}, "bool")
-      .Cases({"I8Attr", "I16Attr", "I32Attr", "I64Attr"}, "int")
-      .Cases({"SI1Attr", "SI8Attr", "SI16Attr", "SI32Attr", "SI64Attr"}, "int")
-      .Cases({"UI1Attr", "UI8Attr", "UI16Attr", "UI32Attr", "UI64Attr"}, "int")
-      .Case("IndexAttr", "int")
-      .Cases({"F32Attr", "F64Attr"}, "float")
-      .Cases({"StrAttr", "SymbolNameAttr"}, "str")
-      .Cases({"FlatSymbolRefAttr", "SymbolRefAttr"}, "str")
-      .Case("TypeAttr", "_ods_ir.Type")
-      .Case("AffineMapAttr", "_ods_ir.AffineMap")
-      .Case("IntegerSetAttr", "_ods_ir.IntegerSet")
-      .Case("DictionaryAttr", "dict")
-      .Case("ArrayAttr", "_Sequence[_ods_ir.Attribute]")
-      .Cases({"I32ArrayAttr", "I64ArrayAttr", "I64SmallVectorArrayAttr"},
-             "_Sequence[int]")
-      .Cases({"F32ArrayAttr", "F64ArrayAttr"}, "_Sequence[float]")
-      .Cases({"BoolArrayAttr", "DenseBoolArrayAttr"}, "_Sequence[bool]")
-      .Cases({"StrArrayAttr", "FlatSymbolRefArrayAttr"}, "_Sequence[str]")
-      .Cases({"DenseI8ArrayAttr", "DenseI16ArrayAttr", "DenseI32ArrayAttr",
-              "DenseI64ArrayAttr"},
-             "_Sequence[int]")
-      .Cases({"DenseF32ArrayAttr", "DenseF64ArrayAttr"}, "_Sequence[float]")
-      .Cases({"I32ElementsAttr", "I64ElementsAttr", "IndexElementsAttr"},
-             "_Union[_Sequence[int], _Buffer]")
-      .Case("F64ElementsAttr", "_Union[_Sequence[float], _Buffer]")
-      .Default(StringRef());
+static StringRef getPythonAttrType(mlir::tblgen::Attribute attr) {
+  auto found = pythonAttrTypeMap.find(attr.getAttrDefName());
+  if (found == pythonAttrTypeMap.end())
+    return StringRef();
+  return found->second;
 }
 
 /// Emits accessors to Op attributes.
@@ -1180,7 +1219,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
         argTypes[idx] = "bool";
       } else {
         std::string attrType = "_ods_ir." + getPythonAttrName(nattr->attr);
-        StringRef rawType = getPythonAttrRawType(nattr->attr);
+        StringRef rawType = getPythonAttrType(nattr->attr);
         argTypes[idx] =
             llvm::formatv("_Union[{0}, {1}]",
                           rawType.empty() ? "_Any" : rawType, attrType)
@@ -1189,7 +1228,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
     } else if (auto *ntype =
                    llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
       std::string type = "_ods_ir.Value";
-      if (StringRef pythonType = getPythonType(ntype->constraint.getCppType());
+      if (StringRef pythonType = getPythonType(ntype->constraint);
           !pythonType.empty()) {
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       }
@@ -1379,8 +1418,7 @@ static void emitValueBuilder(const Operator &op,
       results = ".results";
     } else if (op.getNumResults() == 1) {
       type = "_ods_ir.OpResult";
-      if (StringRef pythonType =
-              getPythonType(op.getResult(0).constraint.getCppType());
+      if (StringRef pythonType = getPythonType(op.getResult(0).constraint);
           !pythonType.empty())
         type = llvm::formatv("{0}[{1}]", type, pythonType);
       results = ".result";
@@ -1451,6 +1489,34 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
   if (dialectNameStorage.empty())
     llvm::PrintFatalError("dialect name not provided");
 
+  // Seeds `table` with the built-in mappings, then merges in every def derived
+  // from `className`, using its `keyField`/`valueField` fields as key/value.
+  // A def that disagrees with an existing entry (built-in or another def) is a
+  // fatal error; exact duplicates are accepted.
+  auto loadMappings =
+      [&](llvm::StringMap<std::string> &table,
+          llvm::ArrayRef<std::pair<StringRef, StringRef>> builtins,
+          StringRef className, StringRef keyField, StringRef valueField) {
+        table.clear();
+        for (const auto &[key, value] : builtins)
+          table[key] = value.str();
+        for (const Record *rec :
+             records.getAllDerivedDefinitionsIfDefined(className)) {
+          StringRef key = rec->getValueAsString(keyField);
+          StringRef value = rec->getValueAsString(valueField);
+          auto [it, inserted] = table.try_emplace(key, value.str());
+          if (!inserted && it->second != value)
+            llvm::PrintFatalError(rec->getLoc(),
+                                  "conflicting " + className + " for '" + key +
+                                      "': '" + it->second + "' vs '" + value +
+                                      "'");
+        }
+      };
+  loadMappings(pythonTypeMap, builtinTypeMappings, "PythonTypeName", "cppName",
+               "pyName");
+  loadMappings(pythonAttrTypeMap, builtinAttrTypeMappings, "PythonAttrType",
+               "defName", "pyType");
+
   os << fileHeader;
   if (!clDialectExtensionName.empty())
     os << formatv(dialectExtensionTemplate, dialectNameStorage);



More information about the Mlir-commits mailing list