[Mlir-commits] [mlir] b544ad5 - [MLIR] [Python] Added a way to extend MLIR->Python type mappings (#189368)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 31 03:00:46 PDT 2026
Author: Sergei Lebedev
Date: 2026-03-31T11:00:40+01:00
New Revision: b544ad57039588d0fe24a1f512202cc5c0bd3a67
URL: https://github.com/llvm/llvm-project/commit/b544ad57039588d0fe24a1f512202cc5c0bd3a67
DIFF: https://github.com/llvm/llvm-project/commit/b544ad57039588d0fe24a1f512202cc5c0bd3a67.diff
LOG: [MLIR] [Python] Added a way to extend MLIR->Python type mappings (#189368)
The idea is to use TableGen records for both custom type constraints and
attributes:
* `PythonTypeName` is for type constraints, while
* `PythonAttrType` is for attributes.
The two records are keyed differently. `PythonTypeName` is
keyed by C++ type because multiple type constraints map to the same C++
type (e.g. `I32` and `I64` both map to `::mlir::IntegerType`), so a
single entry covers all of them. `PythonAttrType` is keyed by TableGen
def name because different attributes can share the same C++ storage
type but need distinct Python types (e.g. `I32ArrayAttr` and
`StrArrayAttr` are both `::mlir::ArrayAttr`).
In principle, `getPythonAttrName` could be reimplemented using the same
approach, but that is left for a future PR.
Added:
mlir/include/mlir/Bindings/Python/PythonBindings.td
Modified:
mlir/docs/Bindings/Python.md
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index df186e43219c2..c43181859f968 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1180,6 +1180,29 @@ filled with `import`s from the generated files to enable `import
mlir.dialects.<dialect-namespace>` in Python.
+#### Customizing Type Annotations
+
+The generated `__init__` methods include type annotations for operand and
+attribute arguments. Built-in mappings cover standard MLIR types and attributes,
+but dialects can extend them by adding definitions from
+[`PythonBindings.td`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PythonBindings.td)
+to the same `.td` file passed to `mlir-tblgen -gen-python-op-bindings`:
+
+```tablegen
+include "mlir/Bindings/Python/PythonBindings.td"
+
+// Operand/result annotations: maps a C++ type from the ODS type constraint's
+// cppClassName to a Python type annotation, e.g.
+// `ir.Value` -> `ir.Value[my_dialect.MyTensorType]`.
+def : PythonTypeName<"::my_dialect::MyTensorType",
+ "my_dialect.MyTensorType">;
+
+// Attribute annotations: maps a TableGen attribute def name to the Python
+// type accepted by its AttrBuilder, e.g.
+// `Union[Any, ir.Attribute]` -> `Union[my_dialect.MyValue, ir.Attribute]`.
+def : PythonAttrType<"MyCustomAttr", "my_dialect.MyValue">;
+```
+
### Attributes and Types
Dialect attributes and types are provided in Python as subclasses of the
diff --git a/mlir/include/mlir/Bindings/Python/PythonBindings.td b/mlir/include/mlir/Bindings/Python/PythonBindings.td
new file mode 100644
index 0000000000000..b652b1381b6d6
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/PythonBindings.td
@@ -0,0 +1,31 @@
+//===-- PythonBindings.td - Python binding type mappings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C++ type and attribute to Python type mappings for -gen-python-op-bindings.
+// Dialects can include this file and add their own mappings.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TD
+#define PYTHON_BINDINGS_TD
+
+/// Maps a C++ type to a Python type annotation for operands and results.
+/// Example: PythonTypeName<"::mlir::IntegerType", "_ods_ir.IntegerType">
+class PythonTypeName<string cppType, string pythonType> {
+ string cppName = cppType;
+ string pyName = pythonType;
+}
+
+/// Maps a TableGen attribute def name to the Python type accepted by its
+/// AttrBuilder. Example: PythonAttrType<"I32Attr", "int">
+class PythonAttrType<string attrDefName, string pythonType> {
+ string defName = attrDefName;
+ string pyType = pythonType;
+}
+
+#endif // PYTHON_BINDINGS_TD
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 5e29f3f61e5c8..f566290915dde 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -3,6 +3,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Bindings/Python/PythonBindings.td"
// CHECK: @_ods_cext.register_dialect
// CHECK: class _Dialect(_ods_ir.Dialect):
@@ -15,6 +16,15 @@ def Test_Dialect : Dialect {
class TestOp<string mnemonic, list<Trait> traits = []> :
Op<Test_Dialect, mnemonic, traits>;
+def TestCustomAttr
+ : DialectAttr<Test_Dialect,
+ CPred<"::llvm::isa<::test::CustomAttr>($_self)">,
+ "custom attribute"> {
+ let storageType = "::test::CustomAttr";
+}
+
+def : PythonAttrType<"TestCustomAttr", "test.CustomType">;
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class AttrSizedOperandsOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.attr_sized_operands"
@@ -217,6 +227,12 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class CustomAttrOp(_ods_ir.OpView):
+def CustomAttrOp : TestOp<"custom_attr"> {
+ // CHECK: def __init__(self, custom: _Union[test.CustomType, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
+ let arguments = (ins TestCustomAttr:$custom);
+}
+
// CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.default_valued_attrs"
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 14500edca5bcf..81c598ebbef0a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -17,6 +17,7 @@
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
@@ -30,6 +31,78 @@ using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
+/// Built-in C++ type to Python type mappings.
+static constexpr std::pair<StringRef, StringRef> builtinTypeMappings[] = {
+ {"::mlir::MemRefType", "_ods_ir.MemRefType"},
+ {"::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType"},
+ {"::mlir::RankedTensorType", "_ods_ir.RankedTensorType"},
+ {"::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType"},
+ {"::mlir::VectorType", "_ods_ir.VectorType"},
+ {"::mlir::IntegerType", "_ods_ir.IntegerType"},
+ {"::mlir::FloatType", "_ods_ir.FloatType"},
+ {"::mlir::IndexType", "_ods_ir.IndexType"},
+ {"::mlir::ComplexType", "_ods_ir.ComplexType"},
+ {"::mlir::TupleType", "_ods_ir.TupleType"},
+ {"::mlir::NoneType", "_ods_ir.NoneType"},
+};
+
+/// Built-in TableGen attribute def name to Python type mappings.
+static constexpr std::pair<StringRef, StringRef> builtinAttrTypeMappings[] = {
+ {"BoolAttr", "bool"},
+ {"I1Attr", "bool"},
+ {"I8Attr", "int"},
+ {"I16Attr", "int"},
+ {"I32Attr", "int"},
+ {"I64Attr", "int"},
+ {"SI1Attr", "int"},
+ {"SI8Attr", "int"},
+ {"SI16Attr", "int"},
+ {"SI32Attr", "int"},
+ {"SI64Attr", "int"},
+ {"UI1Attr", "int"},
+ {"UI8Attr", "int"},
+ {"UI16Attr", "int"},
+ {"UI32Attr", "int"},
+ {"UI64Attr", "int"},
+ {"IndexAttr", "int"},
+ {"F32Attr", "float"},
+ {"F64Attr", "float"},
+ {"StrAttr", "str"},
+ {"SymbolNameAttr", "str"},
+ {"FlatSymbolRefAttr", "str"},
+ {"SymbolRefAttr", "str"},
+ {"TypeAttr", "_ods_ir.Type"},
+ {"AffineMapAttr", "_ods_ir.AffineMap"},
+ {"IntegerSetAttr", "_ods_ir.IntegerSet"},
+ {"DictionaryAttr", "dict"},
+ {"ArrayAttr", "_Sequence[_ods_ir.Attribute]"},
+ {"I32ArrayAttr", "_Sequence[int]"},
+ {"I64ArrayAttr", "_Sequence[int]"},
+ {"I64SmallVectorArrayAttr", "_Sequence[int]"},
+ {"F32ArrayAttr", "_Sequence[float]"},
+ {"F64ArrayAttr", "_Sequence[float]"},
+ {"BoolArrayAttr", "_Sequence[bool]"},
+ {"DenseBoolArrayAttr", "_Sequence[bool]"},
+ {"StrArrayAttr", "_Sequence[str]"},
+ {"FlatSymbolRefArrayAttr", "_Sequence[str]"},
+ {"DenseI8ArrayAttr", "_Sequence[int]"},
+ {"DenseI16ArrayAttr", "_Sequence[int]"},
+ {"DenseI32ArrayAttr", "_Sequence[int]"},
+ {"DenseI64ArrayAttr", "_Sequence[int]"},
+ {"DenseF32ArrayAttr", "_Sequence[float]"},
+ {"DenseF64ArrayAttr", "_Sequence[float]"},
+ {"I32ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+ {"I64ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+ {"IndexElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+ {"F64ElementsAttr", "_Union[_Sequence[float], _Buffer]"},
+};
+
+/// Maps from C++ type names to Python type annotations.
+static llvm::StringMap<std::string> pythonTypeMap;
+
+/// Maps from TableGen attribute def names to Python types.
+static llvm::StringMap<std::string> pythonAttrTypeMap;
+
/// File header and includes.
/// {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
@@ -397,20 +470,13 @@ static std::string attrSizedTraitForKind(const char *kind) {
StringRef(kind).drop_front());
}
-static StringRef getPythonType(StringRef cppType) {
- return llvm::StringSwitch<StringRef>(cppType)
- .Case("::mlir::MemRefType", "_ods_ir.MemRefType")
- .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
- .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
- .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
- .Case("::mlir::VectorType", "_ods_ir.VectorType")
- .Case("::mlir::IntegerType", "_ods_ir.IntegerType")
- .Case("::mlir::FloatType", "_ods_ir.FloatType")
- .Case("::mlir::IndexType", "_ods_ir.IndexType")
- .Case("::mlir::ComplexType", "_ods_ir.ComplexType")
- .Case("::mlir::TupleType", "_ods_ir.TupleType")
- .Case("::mlir::NoneType", "_ods_ir.NoneType")
- .Default(StringRef());
+/// Returns the Python type annotation for a given type constraint.
+/// Returns empty StringRef if no mapping is known.
+static StringRef getPythonType(const tblgen::TypeConstraint &constraint) {
+ auto it = pythonTypeMap.find(constraint.getCppType());
+ if (it != pythonTypeMap.end())
+ return it->second;
+ return StringRef();
}
/// Emits accessors to "elements" of an Op definition. Currently, the supported
@@ -447,7 +513,7 @@ static void emitElementAccessors(
continue;
std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
- if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
@@ -457,8 +523,7 @@ static void emitElementAccessors(
} else {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
: "_ods_ir.OpResultList";
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
@@ -500,8 +565,7 @@ static void emitElementAccessors(
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
}
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -536,8 +600,7 @@ static void emitElementAccessors(
if (!element.isVariableLength() || element.isOptional()) {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -549,8 +612,7 @@ static void emitElementAccessors(
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
}
} else {
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -654,36 +716,13 @@ static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
return "Attribute";
}
-/// Returns the Python raw value type accepted by the AttrBuilder for the given
+/// Returns the Python value type accepted by the AttrBuilder for the given
/// attribute. Returns empty StringRef if no mapping is known.
-static StringRef getPythonAttrRawType(mlir::tblgen::Attribute attr) {
- return llvm::StringSwitch<StringRef>(attr.getAttrDefName())
- .Cases({"BoolAttr", "I1Attr"}, "bool")
- .Cases({"I8Attr", "I16Attr", "I32Attr", "I64Attr"}, "int")
- .Cases({"SI1Attr", "SI8Attr", "SI16Attr", "SI32Attr", "SI64Attr"}, "int")
- .Cases({"UI1Attr", "UI8Attr", "UI16Attr", "UI32Attr", "UI64Attr"}, "int")
- .Case("IndexAttr", "int")
- .Cases({"F32Attr", "F64Attr"}, "float")
- .Cases({"StrAttr", "SymbolNameAttr"}, "str")
- .Cases({"FlatSymbolRefAttr", "SymbolRefAttr"}, "str")
- .Case("TypeAttr", "_ods_ir.Type")
- .Case("AffineMapAttr", "_ods_ir.AffineMap")
- .Case("IntegerSetAttr", "_ods_ir.IntegerSet")
- .Case("DictionaryAttr", "dict")
- .Case("ArrayAttr", "_Sequence[_ods_ir.Attribute]")
- .Cases({"I32ArrayAttr", "I64ArrayAttr", "I64SmallVectorArrayAttr"},
- "_Sequence[int]")
- .Cases({"F32ArrayAttr", "F64ArrayAttr"}, "_Sequence[float]")
- .Cases({"BoolArrayAttr", "DenseBoolArrayAttr"}, "_Sequence[bool]")
- .Cases({"StrArrayAttr", "FlatSymbolRefArrayAttr"}, "_Sequence[str]")
- .Cases({"DenseI8ArrayAttr", "DenseI16ArrayAttr", "DenseI32ArrayAttr",
- "DenseI64ArrayAttr"},
- "_Sequence[int]")
- .Cases({"DenseF32ArrayAttr", "DenseF64ArrayAttr"}, "_Sequence[float]")
- .Cases({"I32ElementsAttr", "I64ElementsAttr", "IndexElementsAttr"},
- "_Union[_Sequence[int], _Buffer]")
- .Case("F64ElementsAttr", "_Union[_Sequence[float], _Buffer]")
- .Default(StringRef());
+static StringRef getPythonAttrType(mlir::tblgen::Attribute attr) {
+ auto it = pythonAttrTypeMap.find(attr.getAttrDefName());
+ if (it != pythonAttrTypeMap.end())
+ return it->second;
+ return StringRef();
}
/// Emits accessors to Op attributes.
@@ -1180,7 +1219,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
argTypes[idx] = "bool";
} else {
std::string attrType = "_ods_ir." + getPythonAttrName(nattr->attr);
- StringRef rawType = getPythonAttrRawType(nattr->attr);
+ StringRef rawType = getPythonAttrType(nattr->attr);
argTypes[idx] =
llvm::formatv("_Union[{0}, {1}]",
rawType.empty() ? "_Any" : rawType, attrType)
@@ -1189,7 +1228,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
} else if (auto *ntype =
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
std::string type = "_ods_ir.Value";
- if (StringRef pythonType = getPythonType(ntype->constraint.getCppType());
+ if (StringRef pythonType = getPythonType(ntype->constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -1379,8 +1418,7 @@ static void emitValueBuilder(const Operator &op,
results = ".results";
} else if (op.getNumResults() == 1) {
type = "_ods_ir.OpResult";
- if (StringRef pythonType =
- getPythonType(op.getResult(0).constraint.getCppType());
+ if (StringRef pythonType = getPythonType(op.getResult(0).constraint);
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
results = ".result";
@@ -1444,6 +1482,25 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
emitValueBuilder(op, functionArgs, os);
}
+static void populateTypeMap(llvm::StringMap<std::string> &map,
+ ArrayRef<std::pair<StringRef, StringRef>> builtins,
+ const RecordKeeper &records, StringRef recordClass,
+ StringRef keyField, StringRef valueField) {
+ map.clear();
+ for (auto [key, value] : builtins)
+ map[key] = value.str();
+ for (const Record *rec :
+ records.getAllDerivedDefinitionsIfDefined(recordClass)) {
+ StringRef key = rec->getValueAsString(keyField);
+ std::string value = rec->getValueAsString(valueField).str();
+ auto [it, inserted] = map.try_emplace(key, std::move(value));
+ if (!inserted && it->second != value)
+ llvm::PrintFatalError(rec->getLoc(),
+ "conflicting " + recordClass + " for '" + key +
+ "': '" + it->second + "' vs '" + value + "'");
+ }
+}
+
/// Emits bindings for the dialect specified in the command line, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
@@ -1451,6 +1508,11 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (dialectNameStorage.empty())
llvm::PrintFatalError("dialect name not provided");
+ populateTypeMap(pythonTypeMap, builtinTypeMappings, records, "PythonTypeName",
+ "cppName", "pyName");
+ populateTypeMap(pythonAttrTypeMap, builtinAttrTypeMappings, records,
+ "PythonAttrType", "defName", "pyType");
+
os << fileHeader;
if (!clDialectExtensionName.empty())
os << formatv(dialectExtensionTemplate, dialectNameStorage);
More information about the Mlir-commits mailing list