[Mlir-commits] [mlir] [MLIR] [Python] Added a way to extend MLIR->Python type mappings (PR #189368)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 30 05:38:52 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Sergei Lebedev (superbobry)
<details>
<summary>Changes</summary>
The idea is to use TableGen records for both custom type constraints and attributes:
* `PythonTypeName` is for type constraints, while
* `PythonAttrType` is for attributes.
The two records are keyed differently. `PythonTypeName` is keyed by C++ type because multiple type constraints map to the same C++ type (e.g. `I32` and `I64` both map to `::mlir::IntegerType`), so a single entry covers all of them. `PythonAttrType` is keyed by TableGen def name because different attributes can share the same C++ storage type but need distinct Python types (e.g. `I32ArrayAttr` and `StrArrayAttr` are both `::mlir::ArrayAttr`).
We could in theory reimplement `getPythonAttrName` using the same approach, but I decided to leave it for future PRs.
---
Full diff: https://github.com/llvm/llvm-project/pull/189368.diff
2 Files Affected:
- (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+16)
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+122-56)
``````````diff
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 5e29f3f61e5c8..712c5a06c3239 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -3,6 +3,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Bindings/Python/PythonBindings.td"
// CHECK: @_ods_cext.register_dialect
// CHECK: class _Dialect(_ods_ir.Dialect):
@@ -15,6 +16,15 @@ def Test_Dialect : Dialect {
class TestOp<string mnemonic, list<Trait> traits = []> :
Op<Test_Dialect, mnemonic, traits>;
+def TestCustomAttr // Test dialect attribute whose storage type is ::test::CustomAttr.
+ : DialectAttr<Test_Dialect,
+ CPred<"::llvm::isa<::test::CustomAttr>($_self)">,
+ "custom attribute"> {
+ let storageType = "::test::CustomAttr";
+}
+
+def : PythonAttrType<"TestCustomAttr", "object">; // Python value type `object` for builder args of this attr.
+
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class AttrSizedOperandsOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.attr_sized_operands"
@@ -217,6 +227,12 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class CustomAttrOp(_ods_ir.OpView):
+def CustomAttrOp : TestOp<"custom_attr"> { // Op taking a TestCustomAttr, to exercise its PythonAttrType mapping.
+ // CHECK: def __init__(self, custom: _Union[object, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
+ let arguments = (ins TestCustomAttr:$custom);
+}
+
// CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.default_valued_attrs"
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 14c76295daaee..277e12c4c6291 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -17,6 +17,7 @@
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
@@ -30,6 +31,78 @@ using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
+/// Built-in C++ type -> Python annotation table; seeds pythonTypeMap.
+static constexpr std::pair<StringRef, StringRef> builtinTypeMappings[] = {
+ {"::mlir::MemRefType", "_ods_ir.MemRefType"},
+ {"::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType"},
+ {"::mlir::RankedTensorType", "_ods_ir.RankedTensorType"},
+ {"::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType"},
+ {"::mlir::VectorType", "_ods_ir.VectorType"},
+ {"::mlir::IntegerType", "_ods_ir.IntegerType"},
+ {"::mlir::FloatType", "_ods_ir.FloatType"},
+ {"::mlir::IndexType", "_ods_ir.IndexType"},
+ {"::mlir::ComplexType", "_ods_ir.ComplexType"},
+ {"::mlir::TupleType", "_ods_ir.TupleType"},
+ {"::mlir::NoneType", "_ods_ir.NoneType"},
+};
+
+/// Built-in attr def name -> Python value type table; seeds pythonAttrTypeMap.
+static constexpr std::pair<StringRef, StringRef> builtinAttrTypeMappings[] = {
+ {"BoolAttr", "bool"},
+ {"I1Attr", "bool"},
+ {"I8Attr", "int"},
+ {"I16Attr", "int"},
+ {"I32Attr", "int"},
+ {"I64Attr", "int"},
+ {"SI1Attr", "int"},
+ {"SI8Attr", "int"},
+ {"SI16Attr", "int"},
+ {"SI32Attr", "int"},
+ {"SI64Attr", "int"},
+ {"UI1Attr", "int"},
+ {"UI8Attr", "int"},
+ {"UI16Attr", "int"},
+ {"UI32Attr", "int"},
+ {"UI64Attr", "int"},
+ {"IndexAttr", "int"},
+ {"F32Attr", "float"},
+ {"F64Attr", "float"},
+ {"StrAttr", "str"},
+ {"SymbolNameAttr", "str"},
+ {"FlatSymbolRefAttr", "str"},
+ {"SymbolRefAttr", "str"},
+ {"TypeAttr", "_ods_ir.Type"},
+ {"AffineMapAttr", "_ods_ir.AffineMap"},
+ {"IntegerSetAttr", "_ods_ir.IntegerSet"},
+ {"DictionaryAttr", "dict"},
+ {"ArrayAttr", "_Sequence[_ods_ir.Attribute]"},
+ {"I32ArrayAttr", "_Sequence[int]"},
+ {"I64ArrayAttr", "_Sequence[int]"},
+ {"I64SmallVectorArrayAttr", "_Sequence[int]"},
+ {"F32ArrayAttr", "_Sequence[float]"},
+ {"F64ArrayAttr", "_Sequence[float]"},
+ {"BoolArrayAttr", "_Sequence[bool]"},
+ {"DenseBoolArrayAttr", "_Sequence[bool]"},
+ {"StrArrayAttr", "_Sequence[str]"},
+ {"FlatSymbolRefArrayAttr", "_Sequence[str]"},
+ {"DenseI8ArrayAttr", "_Sequence[int]"},
+ {"DenseI16ArrayAttr", "_Sequence[int]"},
+ {"DenseI32ArrayAttr", "_Sequence[int]"},
+ {"DenseI64ArrayAttr", "_Sequence[int]"},
+ {"DenseF32ArrayAttr", "_Sequence[float]"},
+ {"DenseF64ArrayAttr", "_Sequence[float]"},
+ {"I32ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+ {"I64ElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+ {"IndexElementsAttr", "_Union[_Sequence[int], _Buffer]"},
+ {"F64ElementsAttr", "_Union[_Sequence[float], _Buffer]"},
+};
+
+/// C++ type name -> Python annotation (builtins + PythonTypeName records).
+static llvm::StringMap<std::string> pythonTypeMap;
+
+/// Attr def name -> Python value type (builtins + PythonAttrType records).
+static llvm::StringMap<std::string> pythonAttrTypeMap;
+
/// File header and includes.
/// {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
@@ -397,20 +470,13 @@ static std::string attrSizedTraitForKind(const char *kind) {
StringRef(kind).drop_front());
}
-static StringRef getPythonType(StringRef cppType) {
- return llvm::StringSwitch<StringRef>(cppType)
- .Case("::mlir::MemRefType", "_ods_ir.MemRefType")
- .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
- .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
- .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
- .Case("::mlir::VectorType", "_ods_ir.VectorType")
- .Case("::mlir::IntegerType", "_ods_ir.IntegerType")
- .Case("::mlir::FloatType", "_ods_ir.FloatType")
- .Case("::mlir::IndexType", "_ods_ir.IndexType")
- .Case("::mlir::ComplexType", "_ods_ir.ComplexType")
- .Case("::mlir::TupleType", "_ods_ir.TupleType")
- .Case("::mlir::NoneType", "_ods_ir.NoneType")
- .Default(StringRef());
+/// Looks up the Python annotation registered for the constraint's C++ type.
+/// An empty StringRef means that no mapping is registered for it.
+static StringRef getPythonType(const tblgen::TypeConstraint &constraint) {
+ auto found = pythonTypeMap.find(constraint.getCppType());
+ if (found == pythonTypeMap.end())
+ return StringRef();
+ return found->second;
}
/// Emits accessors to "elements" of an Op definition. Currently, the supported
@@ -447,7 +513,7 @@ static void emitElementAccessors(
continue;
std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
- if (StringRef pythonType = getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
@@ -457,8 +523,7 @@ static void emitElementAccessors(
} else {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
: "_ods_ir.OpResultList";
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
@@ -500,8 +565,7 @@ static void emitElementAccessors(
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
}
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -536,8 +600,7 @@ static void emitElementAccessors(
if (!element.isVariableLength() || element.isOptional()) {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
: "_ods_ir.OpResult";
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -549,8 +612,7 @@ static void emitElementAccessors(
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
}
} else {
- if (StringRef pythonType =
- getPythonType(element.constraint.getCppType());
+ if (StringRef pythonType = getPythonType(element.constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -654,36 +716,13 @@ static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
return "Attribute";
}
-/// Returns the Python raw value type accepted by the AttrBuilder for the given
+/// Returns the Python value type accepted by the AttrBuilder for the given
/// attribute. Returns empty StringRef if no mapping is known.
-static StringRef getPythonAttrRawType(mlir::tblgen::Attribute attr) {
- return llvm::StringSwitch<StringRef>(attr.getAttrDefName())
- .Cases({"BoolAttr", "I1Attr"}, "bool")
- .Cases({"I8Attr", "I16Attr", "I32Attr", "I64Attr"}, "int")
- .Cases({"SI1Attr", "SI8Attr", "SI16Attr", "SI32Attr", "SI64Attr"}, "int")
- .Cases({"UI1Attr", "UI8Attr", "UI16Attr", "UI32Attr", "UI64Attr"}, "int")
- .Case("IndexAttr", "int")
- .Cases({"F32Attr", "F64Attr"}, "float")
- .Cases({"StrAttr", "SymbolNameAttr"}, "str")
- .Cases({"FlatSymbolRefAttr", "SymbolRefAttr"}, "str")
- .Case("TypeAttr", "_ods_ir.Type")
- .Case("AffineMapAttr", "_ods_ir.AffineMap")
- .Case("IntegerSetAttr", "_ods_ir.IntegerSet")
- .Case("DictionaryAttr", "dict")
- .Case("ArrayAttr", "_Sequence[_ods_ir.Attribute]")
- .Cases({"I32ArrayAttr", "I64ArrayAttr", "I64SmallVectorArrayAttr"},
- "_Sequence[int]")
- .Cases({"F32ArrayAttr", "F64ArrayAttr"}, "_Sequence[float]")
- .Cases({"BoolArrayAttr", "DenseBoolArrayAttr"}, "_Sequence[bool]")
- .Cases({"StrArrayAttr", "FlatSymbolRefArrayAttr"}, "_Sequence[str]")
- .Cases({"DenseI8ArrayAttr", "DenseI16ArrayAttr", "DenseI32ArrayAttr",
- "DenseI64ArrayAttr"},
- "_Sequence[int]")
- .Cases({"DenseF32ArrayAttr", "DenseF64ArrayAttr"}, "_Sequence[float]")
- .Cases({"I32ElementsAttr", "I64ElementsAttr", "IndexElementsAttr"},
- "_Union[_Sequence[int], _Buffer]")
- .Case("F64ElementsAttr", "_Union[_Sequence[float], _Buffer]")
- .Default(StringRef());
+static StringRef getPythonAttrType(mlir::tblgen::Attribute attr) {
+ auto found = pythonAttrTypeMap.find(attr.getAttrDefName());
+ if (found == pythonAttrTypeMap.end())
+ return StringRef();
+ return found->second;
}
/// Emits accessors to Op attributes.
@@ -1180,7 +1219,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
argTypes[idx] = "bool";
} else {
std::string attrType = "_ods_ir." + getPythonAttrName(nattr->attr);
- StringRef rawType = getPythonAttrRawType(nattr->attr);
+ StringRef rawType = getPythonAttrType(nattr->attr);
argTypes[idx] =
llvm::formatv("_Union[{0}, {1}]",
rawType.empty() ? "_Any" : rawType, attrType)
@@ -1189,7 +1228,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
} else if (auto *ntype =
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
std::string type = "_ods_ir.Value";
- if (StringRef pythonType = getPythonType(ntype->constraint.getCppType());
+ if (StringRef pythonType = getPythonType(ntype->constraint);
!pythonType.empty()) {
type = llvm::formatv("{0}[{1}]", type, pythonType);
}
@@ -1379,8 +1418,7 @@ static void emitValueBuilder(const Operator &op,
results = ".results";
} else if (op.getNumResults() == 1) {
type = "_ods_ir.OpResult";
- if (StringRef pythonType =
- getPythonType(op.getResult(0).constraint.getCppType());
+ if (StringRef pythonType = getPythonType(op.getResult(0).constraint);
!pythonType.empty())
type = llvm::formatv("{0}[{1}]", type, pythonType);
results = ".result";
@@ -1451,6 +1489,34 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (dialectNameStorage.empty())
llvm::PrintFatalError("dialect name not provided");
+ pythonTypeMap.clear();
+ for (auto [key, value] : builtinTypeMappings)
+ pythonTypeMap[key] = value.str();
+ for (const Record *rec :
+ records.getAllDerivedDefinitionsIfDefined("PythonTypeName")) {
+ StringRef key = rec->getValueAsString("cppName");
+ std::string value = rec->getValueAsString("pyName").str();
+ auto [it, inserted] = pythonTypeMap.try_emplace(key, std::move(value));
+ if (!inserted && it->second != value)
+ llvm::PrintFatalError(rec->getLoc(), "conflicting PythonTypeName for '" +
+ key + "': '" + it->second +
+ "' vs '" + value + "'");
+ }
+
+ pythonAttrTypeMap.clear();
+ for (auto [key, value] : builtinAttrTypeMappings)
+ pythonAttrTypeMap[key] = value.str();
+ for (const Record *rec :
+ records.getAllDerivedDefinitionsIfDefined("PythonAttrType")) {
+ StringRef key = rec->getValueAsString("defName");
+ std::string value = rec->getValueAsString("pyType").str();
+ auto [it, inserted] = pythonAttrTypeMap.try_emplace(key, std::move(value));
+ if (!inserted && it->second != value)
+ llvm::PrintFatalError(rec->getLoc(), "conflicting PythonAttrType for '" +
+ key + "': '" + it->second +
+ "' vs '" + value + "'");
+ }
+
os << fileHeader;
if (!clDialectExtensionName.empty())
os << formatv(dialectExtensionTemplate, dialectNameStorage);
``````````
</details>
https://github.com/llvm/llvm-project/pull/189368
More information about the Mlir-commits
mailing list