[Mlir-commits] [mlir] [MLIR][Python] Add python-side adaptor class codegen in mlir-tblgen (PR #176640)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 18 03:05:42 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/176640
🚧 This is experimental and a work in progress. 🚧
In dialect conversion, the operation adaptor is a fairly important helper type. Before we introduce the dialect conversion API into the MLIR Python bindings, we first need to generate the corresponding adaptor class for each op type (i.e., for each `OpView` subclass). This PR is an attempt to add support for generating adaptor classes in mlir-tblgen’s Python op class generator.
From 36e3ff81cd631d2ba07b536426a32622d6c842cd Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 18 Jan 2026 18:59:35 +0800
Subject: [PATCH] [MLIR][Python] Add python-side adaptor class codegen in
mlir-tblgen
---
mlir/python/mlir/dialects/_ods_common.py | 6 ++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 71 ++++++++++++-------
2 files changed, 53 insertions(+), 24 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 10abd06ff266e..e8b7aa81ef920 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -305,3 +305,9 @@ def _get_int_array_array_attr(
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
+
+
+class OpAdaptor:
+ def __init__(self, operands, attributes) -> None:
+ self.operands = operands
+ self.attributes = attributes
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6545559ff1b10..6571db1796010 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -40,6 +40,7 @@ from ._ods_common import (
get_default_loc_context as _ods_get_default_loc_context,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
+ OpAdaptor as _ods_OpAdaptor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
@@ -69,6 +70,15 @@ constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):{2}
OPERATION_NAME = "{1}"
+ Adaptor = {0}Adaptor
+)Py";
+
+/// Template for the Python operation adaptor class (named {0}Adaptor):
+/// {0} is the Python class name of the operation;
+/// {1} is the operation name.
+constexpr const char *opAdaptorClassTemplate = R"Py(
+class {0}Adaptor(_ods_OpAdaptor):
+ OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
@@ -99,7 +109,7 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
constexpr const char *opSingleTemplate = R"Py(
@builtins.property
def {0}(self) -> {3}:
- return self.operation.{1}s[{2}]
+ return self.{1}s[{2}]
)Py";
/// Template for single-element accessor after a variable-length group:
@@ -113,8 +123,8 @@ constexpr const char *opSingleTemplate = R"Py(
constexpr const char *opSingleAfterVariableTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
- _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
- return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
+ _ods_variadic_group_length = len(self.{1}s) - {2} + 1
+ return self.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
/// Template for an optional element accessor:
@@ -129,7 +139,7 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
def {0}(self) -> _Optional[{4}]:
- return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
+ return None if len(self.{1}s) < {2} else self.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case:
@@ -141,8 +151,8 @@ constexpr const char *opOneOptionalTemplate = R"Py(
constexpr const char *opOneVariadicTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
- _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
- return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
+ _ods_variadic_group_length = len(self.{1}s) - {2} + 1
+ return self.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
/// First part of the template for equally-sized variadic group accessor:
@@ -156,20 +166,20 @@ constexpr const char *opOneVariadicTemplate = R"Py(
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
def {0}(self) -> {6}:
- start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
+ start, elements_per_group = _ods_equally_sized_accessor(self.{1}s, {2}, {3}, {4}, {5}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
- return self.operation.{0}s[start]
+ return self.{0}s[start]
)Py";
/// Second part of the template for equally-sized case, accessing a variadic
/// group:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
- return self.operation.{0}s[start:start + elements_per_group]
+ return self.{0}s[start:start + elements_per_group]
)Py";
/// Template for an attribute-sized group accessor:
@@ -177,14 +187,15 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
-/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
-/// {4} is the type hint.
+/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional);
+/// {4} is the type hint;
+/// {5} is the Python path of the element list, minus its trailing 's'.
constexpr const char *opVariadicSegmentTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
{1}_range = _ods_segmented_accessor(
- self.operation.{1}s,
- self.operation.attributes["{1}SegmentSizes"], {2})
+ self.{5}s,
+ self.attributes["{1}SegmentSizes"], {2})
return {1}_range{3}
)Py";
@@ -364,7 +375,8 @@ static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
unsigned numVariadicGroups, unsigned numElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
- getElement) {
+ getElement,
+ bool isAdaptor = false) {
assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
kind) &&
"unsupported kind");
@@ -375,6 +387,8 @@ static void emitElementAccessors(
StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
+ std::string pyAttrName = isAdaptor ? kind : std::string("operation.") + kind;
+
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
@@ -393,20 +407,20 @@ static void emitElementAccessors(
type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
if (element.isOptional()) {
- os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
- numElements, i, type);
+ os << formatv(opOneOptionalTemplate, sanitizeName(element.name),
+ pyAttrName, numElements, i, type);
} else {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
: "_ods_ir.OpResultList";
- os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
- numElements, i, type);
+ os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
+ pyAttrName, numElements, i, type);
}
} else if (seenVariableLength) {
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
- kind, numElements, i, type);
+ pyAttrName, numElements, i, type);
} else {
- os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
- type);
+ os << formatv(opSingleTemplate, sanitizeName(element.name), pyAttrName,
+ i, type);
}
}
return;
@@ -444,12 +458,12 @@ static void emitElementAccessors(
type += "[" + pythonType.str() + "]";
}
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
- kind, numSimpleLength, numVariadicGroups,
+ pyAttrName, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic, type);
os << formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
- kind);
+ pyAttrName);
}
if (element.isVariableLength())
++numPrecedingVariadic;
@@ -490,7 +504,7 @@ static void emitElementAccessors(
}
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
- i, trailing, type);
+ i, trailing, type, pyAttrName);
}
return;
}
@@ -1193,8 +1207,17 @@ static std::string makeDocStringForOp(const Operator &op) {
return docString;
}
+static void emitAdaptorOperandAccessors(const Operator &op, raw_ostream &os) {
+ emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
+ getNumOperands(op), getOperand, /*isAdaptor=*/true);
+}
+
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
+ os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
+ op.getOperationName());
+ emitAdaptorOperandAccessors(op, os);
+
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
makeDocStringForOp(op));
More information about the Mlir-commits
mailing list