[Mlir-commits] [mlir] [MLIR][Python] Support dialect conversion in python bindings (PR #177782)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 28 06:22:12 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/177782
>From 36e3ff81cd631d2ba07b536426a32622d6c842cd Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 18 Jan 2026 18:59:35 +0800
Subject: [PATCH 01/13] [MLIR][Python] Add python-side adaptor class codegen in
mlir-tblgen
---
mlir/python/mlir/dialects/_ods_common.py | 6 ++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 71 ++++++++++++-------
2 files changed, 53 insertions(+), 24 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 10abd06ff266e..e8b7aa81ef920 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -305,3 +305,9 @@ def _get_int_array_array_attr(
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
+
+
+class OpAdaptor:
+ def __init__(self, operands, attributes) -> None:
+ self.operands = operands
+ self.attributes = attributes
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6545559ff1b10..6571db1796010 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -40,6 +40,7 @@ from ._ods_common import (
get_default_loc_context as _ods_get_default_loc_context,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
+ OpAdaptor as _ods_OpAdaptor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
@@ -69,6 +70,15 @@ constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):{2}
OPERATION_NAME = "{1}"
+ Adaptor = {0}Adaptor
+)Py";
+
+/// Template for operation adaptor class:
+/// {0} is the Python class name;
+/// {1} is the operation name.
+constexpr const char *opAdaptorClassTemplate = R"Py(
+class {0}Adaptor(_ods_OpAdaptor):
+ OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
@@ -99,7 +109,7 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
constexpr const char *opSingleTemplate = R"Py(
@builtins.property
def {0}(self) -> {3}:
- return self.operation.{1}s[{2}]
+ return self.{1}s[{2}]
)Py";
/// Template for single-element accessor after a variable-length group:
@@ -113,8 +123,8 @@ constexpr const char *opSingleTemplate = R"Py(
constexpr const char *opSingleAfterVariableTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
- _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
- return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
+ _ods_variadic_group_length = len(self.{1}s) - {2} + 1
+ return self.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
/// Template for an optional element accessor:
@@ -129,7 +139,7 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
def {0}(self) -> _Optional[{4}]:
- return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
+ return None if len(self.{1}s) < {2} else self.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case:
@@ -141,8 +151,8 @@ constexpr const char *opOneOptionalTemplate = R"Py(
constexpr const char *opOneVariadicTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
- _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
- return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
+ _ods_variadic_group_length = len(self.{1}s) - {2} + 1
+ return self.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
/// First part of the template for equally-sized variadic group accessor:
@@ -156,20 +166,20 @@ constexpr const char *opOneVariadicTemplate = R"Py(
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
def {0}(self) -> {6}:
- start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
+ start, elements_per_group = _ods_equally_sized_accessor(self.{1}s, {2}, {3}, {4}, {5}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
- return self.operation.{0}s[start]
+ return self.{0}s[start]
)Py";
/// Second part of the template for equally-sized case, accessing a variadic
/// group:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
- return self.operation.{0}s[start:start + elements_per_group]
+ return self.{0}s[start:start + elements_per_group]
)Py";
/// Template for an attribute-sized group accessor:
@@ -177,14 +187,15 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
-/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
-/// {4} is the type hint.
+/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional);
+/// {4} is the type hint;
+/// {5} is the instance variable name in python.
constexpr const char *opVariadicSegmentTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
{1}_range = _ods_segmented_accessor(
- self.operation.{1}s,
- self.operation.attributes["{1}SegmentSizes"], {2})
+ self.{5}s,
+ self.attributes["{1}SegmentSizes"], {2})
return {1}_range{3}
)Py";
@@ -364,7 +375,8 @@ static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
unsigned numVariadicGroups, unsigned numElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
- getElement) {
+ getElement,
+ bool isAdaptor = false) {
assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
kind) &&
"unsupported kind");
@@ -375,6 +387,8 @@ static void emitElementAccessors(
StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
+ std::string pyAttrName = isAdaptor ? kind : std::string("operation.") + kind;
+
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
@@ -393,20 +407,20 @@ static void emitElementAccessors(
type = llvm::formatv("{0}[{1}]", type, pythonType);
if (element.isVariableLength()) {
if (element.isOptional()) {
- os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
- numElements, i, type);
+ os << formatv(opOneOptionalTemplate, sanitizeName(element.name),
+ pyAttrName, numElements, i, type);
} else {
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
: "_ods_ir.OpResultList";
- os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
- numElements, i, type);
+ os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
+ pyAttrName, numElements, i, type);
}
} else if (seenVariableLength) {
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
- kind, numElements, i, type);
+ pyAttrName, numElements, i, type);
} else {
- os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
- type);
+ os << formatv(opSingleTemplate, sanitizeName(element.name), pyAttrName,
+ i, type);
}
}
return;
@@ -444,12 +458,12 @@ static void emitElementAccessors(
type += "[" + pythonType.str() + "]";
}
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
- kind, numSimpleLength, numVariadicGroups,
+ pyAttrName, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic, type);
os << formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
- kind);
+ pyAttrName);
}
if (element.isVariableLength())
++numPrecedingVariadic;
@@ -490,7 +504,7 @@ static void emitElementAccessors(
}
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
- i, trailing, type);
+ i, trailing, type, pyAttrName);
}
return;
}
@@ -1193,8 +1207,17 @@ static std::string makeDocStringForOp(const Operator &op) {
return docString;
}
+static void emitAdaptorOperandAccessors(const Operator &op, raw_ostream &os) {
+ emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
+ getNumOperands(op), getOperand, /*isAdaptor=*/true);
+}
+
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
+ os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
+ op.getOperationName());
+ emitAdaptorOperandAccessors(op, os);
+
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
makeDocStringForOp(op));
>From c7039afdd7c04000b838982b6b1751e0d160368b Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 18 Jan 2026 22:15:34 +0800
Subject: [PATCH 02/13] fix test case
---
mlir/test/mlir-tblgen/op-python-bindings.td | 76 +++++++++----------
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 8 +-
2 files changed, 43 insertions(+), 41 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 929851724ba71..a4cb5fdacbe30 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -16,8 +16,8 @@ class TestOp<string mnemonic, list<Trait> traits = []> :
Op<Test_Dialect, mnemonic, traits>;
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
+// CHECK-LABEL: class AttrSizedOperandsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attr_sized_operands"
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
@@ -64,8 +64,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
+// CHECK-LABEL: class AttrSizedResultsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attr_sized_results"
// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
@@ -114,8 +114,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttributedOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
+// CHECK-LABEL: class AttributedOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attributed_op"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOp : TestOp<"attributed_op"> {
@@ -164,8 +164,8 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
+// CHECK-LABEL: class AttributedOpWithOperands(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.attributed_op_with_operands"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
@@ -201,8 +201,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.default_valued_attrs"
+// CHECK-LABEL: class DefaultValuedAttrsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.default_valued_attrs"
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
// CHECK: operands = []
@@ -283,8 +283,8 @@ def DescriptionOp : TestOp<"description"> {
}
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class EmptyOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.empty"
+// CHECK-LABEL: class EmptyOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.empty"
def EmptyOp : TestOp<"empty">;
// CHECK: def __init__(self, *, loc=None, ip=None):
// CHECK: operands = []
@@ -329,8 +329,8 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
// CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class MissingNamesOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
+// CHECK-LABEL: class MissingNamesOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.missing_names"
def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
// CHECK: operands = []
@@ -368,8 +368,8 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
+// CHECK-LABEL: class OneOptionalOperandOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.one_optional_operand"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
@@ -400,8 +400,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
+// CHECK-LABEL: class OneVariadicOperandOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.one_variadic_operand"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
@@ -433,8 +433,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
+// CHECK-LABEL: class OneVariadicResultOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.one_variadic_result"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
@@ -468,8 +468,8 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class PythonKeywordOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
+// CHECK-LABEL: class PythonKeywordOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.python_keyword"
def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: def __init__(self, in_, *, loc=None, ip=None):
// CHECK: operands = []
@@ -518,8 +518,8 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_operand"
+// CHECK-LABEL: class SameVariadicOperandSizeOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.same_variadic_operand"
def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
@@ -544,8 +544,8 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
+// CHECK-LABEL: class SameVariadicResultSizeOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.same_variadic_result"
def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
@@ -571,8 +571,8 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class SimpleOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.simple"
+// CHECK-LABEL: class SimpleOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.simple"
def SimpleOp : TestOp<"simple"> {
// CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
// CHECK: operands = []
@@ -611,8 +611,8 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
-// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
+// CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
// CHECK: operands = []
@@ -639,8 +639,8 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> VariadicAndNormalRegionOp:
// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
-// CHECK: class VariadicRegionOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
+// CHECK-LABEL: class VariadicRegionOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.variadic_region"
def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
// CHECK: operands = []
@@ -664,8 +664,8 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.123with--special.characters"
+// CHECK-LABEL: class WithSpecialCharactersOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.123with--special.characters"
def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}
@@ -673,8 +673,8 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)
// CHECK: @_ods_cext.register_operation(_Dialect)
-// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
+// CHECK-LABEL: class WithSuccessorsOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.with_successors"
def WithSuccessorsOp : TestOp<"with_successors"> {
// CHECK-NOT: _ods_successors = None
// CHECK: _ods_successors = []
@@ -687,8 +687,8 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> WithSuccessorsOp:
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
-// CHECK: class snake_case(_ods_ir.OpView):
-// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
+// CHECK-LABEL: class snake_case(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.snake_case"
def already_snake_case : TestOp<"snake_case"> {}
// CHECK: def snake_case_(*, loc=None, ip=None) -> snake_case:
// CHECK: return snake_case(loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6571db1796010..6eed815591751 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -189,13 +189,14 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
/// {3} is a return suffix (expected [0] for single-element, empty for
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional);
/// {4} is the type hint;
-/// {5} is the instance variable name in python.
+/// {5} is the instance variable name in python;
+/// {6} is the instance variable name for attributes in python.
constexpr const char *opVariadicSegmentTemplate = R"Py(
@builtins.property
def {0}(self) -> {4}:
{1}_range = _ods_segmented_accessor(
self.{5}s,
- self.attributes["{1}SegmentSizes"], {2})
+ self.{6}["{1}SegmentSizes"], {2})
return {1}_range{3}
)Py";
@@ -504,7 +505,8 @@ static void emitElementAccessors(
}
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
- i, trailing, type, pyAttrName);
+ i, trailing, type, pyAttrName,
+ isAdaptor ? "attributes" : "operation.attributes");
}
return;
}
>From 0b28512ccd441638c8fdf97a1aa4c52cb4d7c279 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 19 Jan 2026 23:23:29 +0800
Subject: [PATCH 03/13] add attr getters
---
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 62 +++++++++++++++++++
1 file changed, 62 insertions(+)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6eed815591751..74f6d7edea4c2 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -228,6 +228,28 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
return self.operation.attributes["{1}"]
)Py";
+/// Template for an operation attribute getter for adaptors:
+/// {0} is the name of the attribute sanitized for Python;
+/// {1} is the original name of the attribute;
+/// {2} is the type hint.
+constexpr const char *adaptorAttributeGetterTemplate = R"Py(
+ @builtins.property
+ def {0}(self) -> {2}:
+ return self.attributes["{1}"]
+)Py";
+
+/// Template for an optional attribute getter for adaptors (None if absent):
+/// {0} is the name of the attribute sanitized for Python;
+/// {1} is the original name of the attribute;
+/// {2} is the type hint.
+constexpr const char *adaptorOptionalAttributeGetterTemplate = R"Py(
+ @builtins.property
+ def {0}(self) -> _Optional[{2}]:
+ if "{1}" not in self.attributes:
+ return None
+ return self.attributes["{1}"]
+)Py";
+
/// Template for a getter of a unit operation attribute, returns True of the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
@@ -239,6 +261,17 @@ constexpr const char *unitAttributeGetterTemplate = R"Py(
return "{1}" in self.operation.attributes
)Py";
+/// Template for a getter of a unit operation attribute for adaptors, returns
+/// True if the unit attribute is present, False otherwise (unit attributes have
+/// meaning by mere presence):
+/// {0} is the name of the attribute sanitized for Python;
+/// {1} is the original name of the attribute.
+constexpr const char *adaptorUnitAttributeGetterTemplate = R"Py(
+ @builtins.property
+ def {0}(self) -> bool:
+ return "{1}" in self.attributes
+)Py";
+
/// Template for an operation attribute setter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
@@ -640,6 +673,34 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
}
}
+/// Emits accessors to Op attributes for adaptors.
+static void emitAdaptorAttributeAccessors(const Operator &op, raw_ostream &os) {
+ for (const auto &namedAttr : op.getAttributes()) {
+ // Skip "derived" attributes because they are just C++ functions that we
+ // don't currently expose.
+ if (namedAttr.attr.isDerivedAttr())
+ continue;
+
+ if (namedAttr.name.empty())
+ continue;
+
+ std::string sanitizedName = sanitizeName(namedAttr.name);
+
+ // Unit attributes are handled specially.
+ if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
+ os << formatv(adaptorUnitAttributeGetterTemplate, sanitizedName,
+ namedAttr.name);
+ continue;
+ }
+
+ std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
+ os << formatv(namedAttr.attr.isOptional()
+ ? adaptorOptionalAttributeGetterTemplate
+ : adaptorAttributeGetterTemplate,
+ sanitizedName, namedAttr.name, type);
+ }
+}
+
/// Template for the default auto-generated builder.
/// {0} is a comma-separated list of builder arguments, including the trailing
/// `loc` and `ip`;
@@ -1219,6 +1280,7 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
op.getOperationName());
emitAdaptorOperandAccessors(op, os);
+ emitAdaptorAttributeAccessors(op, os);
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
makeDocStringForOp(op));
>From 694e813bd00ae46eae1f8e818e501ade6daa8263 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 19 Jan 2026 23:46:41 +0800
Subject: [PATCH 04/13] add test
---
mlir/test/mlir-tblgen/op-python-bindings.td | 37 ++++++++++++++++++++-
1 file changed, 36 insertions(+), 1 deletion(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index a4cb5fdacbe30..2d44824c9d8fd 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -112,10 +112,26 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: op = AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
-
+// CHECK-LABEL: class AttributedOpAdaptor(_ods_OpAdaptor):
+// CHECK: OPERATION_NAME = "test.attributed_op"
+// CHECK: @builtins.property
+// CHECK: def i32attr(self) -> _ods_ir.IntegerAttr:
+// CHECK: return self.attributes["i32attr"]
+// CHECK: @builtins.property
+// CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.FloatAttr]:
+// CHECK: if "optionalF32Attr" not in self.attributes:
+// CHECK: return None
+// CHECK: return self.attributes["optionalF32Attr"]
+// CHECK: @builtins.property
+// CHECK: def unitAttr(self) -> bool:
+// CHECK: return "unitAttr" in self.attributes
+// CHECK: @builtins.property
+// CHECK: def in_(self) -> _ods_ir.IntegerAttr:
+// CHECK: return self.attributes["in"]
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class AttributedOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.attributed_op"
+// CHECK: Adaptor = AttributedOpAdaptor
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOp : TestOp<"attributed_op"> {
@@ -367,9 +383,18 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
+// CHECK-LABEL: class OneOptionalOperandOpAdaptor(_ods_OpAdaptor):
+// CHECK: OPERATION_NAME = "test.one_optional_operand"
+// CHECK: @builtins.property
+// CHECK: def non_optional(self) -> _ods_ir.Value:
+// CHECK: return self.operands[0]
+// CHECK: @builtins.property
+// CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
+// CHECK: return None if len(self.operands) < 2 else self.operands[1]
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class OneOptionalOperandOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.one_optional_operand"
+// CHECK: Adaptor = OneOptionalOperandOpAdaptor
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
@@ -399,9 +424,19 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> OneOptionalOperandOp:
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
+// CHECK-LABEL: class OneVariadicOperandOpAdaptor(_ods_OpAdaptor):
+// CHECK: OPERATION_NAME = "test.one_variadic_operand"
+// CHECK: @builtins.property
+// CHECK: def non_variadic(self) -> _ods_ir.Value:
+// CHECK: return self.operands[0]
+// CHECK: @builtins.property
+// CHECK: def variadic(self) -> _ods_ir.OpOperandList:
+// CHECK: _ods_variadic_group_length = len(self.operands) - 2 + 1
+// CHECK: return self.operands[1:1 + _ods_variadic_group_length]
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.one_variadic_operand"
+// CHECK: Adaptor = OneVariadicOperandOpAdaptor
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
>From 5b89f1e80d62918b2cf87de782a77896ed9f50cb Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 24 Jan 2026 21:37:27 +0800
Subject: [PATCH 05/13] switch to register
---
mlir/include/mlir/Bindings/Python/Globals.h | 14 ++++
mlir/lib/Bindings/Python/Globals.cpp | 30 ++++++++
mlir/lib/Bindings/Python/IRCore.cpp | 22 ++++++
mlir/test/mlir-tblgen/op-python-bindings.td | 72 +++++++++----------
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 13 ++--
5 files changed, 109 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 5548a716cbe21..ad58c9374f766 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -94,6 +94,12 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
void registerOperationImpl(const std::string &operationName,
nanobind::object pyClass, bool replace = false);
+ /// Adds an operation adaptor class, keyed by full operation name.
+ /// Raises an exception if the mapping already exists and replace == false.
+ /// This is intended to be called by implementation code.
+ void registerOpAdaptorImpl(const std::string &operationName,
+ nanobind::object pyClass, bool replace = false);
+
/// Returns the custom Attribute builder for Attribute kind.
std::optional<nanobind::callable>
lookupAttributeBuilder(const std::string &attributeKind);
@@ -117,6 +123,12 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
+ /// Looks up a registered operation adaptor class by operation
+ /// name. Note that this may trigger a load of the dialect, which can
+ /// arbitrarily re-enter.
+ std::optional<nanobind::object>
+ lookupOpAdaptorClass(llvm::StringRef operationName);
+
class MLIR_PYTHON_API_EXPORTED TracebackLoc {
public:
bool locTracebacksEnabled();
@@ -184,6 +196,8 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<nanobind::object> operationClassMap;
+ /// Map of full operation name to external operation adaptor class object.
+ llvm::StringMap<nanobind::object> opAdaptorClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index e2e8693ba45f3..3d7ee3d30656e 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -137,6 +137,18 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
found = std::move(pyClass);
}
+void PyGlobals::registerOpAdaptorImpl(const std::string &operationName,
+                                      nb::object pyClass, bool replace) {
+  nb::ft_lock_guard lock(mutex);
+  // Fetch (creating if needed) the registry slot for this operation. A slot
+  // that already holds a class may only be overwritten when `replace` is set.
+  nb::object &slot = opAdaptorClassMap[operationName];
+  const bool alreadyRegistered = static_cast<bool>(slot);
+  if (alreadyRegistered && !replace)
+    throw std::runtime_error((llvm::Twine("Operation adaptor of '") +
+                              operationName + "' is already registered.")
+                                 .str());
+  slot = std::move(pyClass);
+}
+
std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
nb::ft_lock_guard lock(mutex);
@@ -207,6 +219,24 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
return std::nullopt;
}
+std::optional<nb::object>
+PyGlobals::lookupOpAdaptorClass(llvm::StringRef operationName) {
+  // Adaptor classes are registered (via the register_op_adaptor decorator)
+  // when their dialect's Python module is imported, so import it first. This
+  // may re-enter Python, hence it happens before taking the lock.
+  llvm::StringRef dialectNamespace = operationName.split('.').first;
+  if (!loadDialectModule(dialectNamespace))
+    return std::nullopt;
+
+  nb::ft_lock_guard lock(mutex);
+  auto entry = opAdaptorClassMap.find(operationName);
+  if (entry == opAdaptorClassMap.end())
+    return std::nullopt; // Loading the dialect did not register an adaptor.
+  assert(entry->second && "OpAdaptor is defined");
+  return entry->second;
+}
+
bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
nanobind::ft_lock_guard lock(mutex);
return locTracebackEnabled_;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index eb00363a54034..c8693dc7624d6 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2774,6 +2774,28 @@ void populateRoot(nb::module_ &m) {
"dialect_class"_a, nb::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
+  m.def(
+      "register_op_adaptor",
+      [](const nb::type_object &opClass, bool replace) -> nb::object {
+        // Returns a class decorator. Applied to an adaptor class, it records
+        // the adaptor in the global registry under the adaptor's
+        // OPERATION_NAME and installs it as `opClass.Adaptor`.
+        return nb::cpp_function(
+            [opClass,
+             replace](nb::type_object adaptorClass) -> nb::type_object {
+              std::string operationName =
+                  nb::cast<std::string>(adaptorClass.attr("OPERATION_NAME"));
+              // Refuse to register an adaptor under one operation name while
+              // attaching it to the class of a different operation; otherwise
+              // the registry and `opClass.Adaptor` would silently disagree.
+              if (nb::hasattr(opClass, "OPERATION_NAME")) {
+                std::string opClassName =
+                    nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
+                if (opClassName != operationName)
+                  throw std::runtime_error(
+                      "Operation adaptor for '" + operationName +
+                      "' cannot be attached to the class of operation '" +
+                      opClassName + "'.");
+              }
+              PyGlobals::get().registerOpAdaptorImpl(operationName,
+                                                     adaptorClass, replace);
+              // Expose the adaptor on the operation class as `Adaptor`.
+              opClass.attr("Adaptor") = adaptorClass;
+              return adaptorClass;
+            });
+      },
+      // clang-format off
+      nb::sig("def register_op_adaptor(op_class: type, *, replace: bool = False) "
+              "-> typing.Callable[[type[T]], type[T]]"),
+      // clang-format on
+      "op_class"_a, nb::kw_only(), "replace"_a = false,
+      "Produce a class decorator for registering an OpAdaptor class for an "
+      "operation.");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
[](PyTypeID mlirTypeID, bool replace) -> nb::object {
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 2d44824c9d8fd..7d62b7cc943c6 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -112,26 +112,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: op = AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
-// CHECK-LABEL: class AttributedOpAdaptor(_ods_OpAdaptor):
-// CHECK: OPERATION_NAME = "test.attributed_op"
-// CHECK: @builtins.property
-// CHECK: def i32attr(self) -> _ods_ir.IntegerAttr:
-// CHECK: return self.attributes["i32attr"]
-// CHECK: @builtins.property
-// CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.FloatAttr]:
-// CHECK: if "optionalF32Attr" not in self.attributes:
-// CHECK: return None
-// CHECK: return self.attributes["optionalF32Attr"]
-// CHECK: @builtins.property
-// CHECK: def unitAttr(self) -> bool:
-// CHECK: return "unitAttr" in self.attributes
-// CHECK: @builtins.property
-// CHECK: def in_(self) -> _ods_ir.IntegerAttr:
-// CHECK: return self.attributes["in"]
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class AttributedOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.attributed_op"
-// CHECK: Adaptor = AttributedOpAdaptor
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def AttributedOp : TestOp<"attributed_op"> {
@@ -175,6 +158,23 @@ def AttributedOp : TestOp<"attributed_op"> {
let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
UnitAttr:$unitAttr, I32Attr:$in);
}
+// CHECK: @_ods_cext.register_op_adaptor(AttributedOp)
+// CHECK-LABEL: class AttributedOpAdaptor(_ods_OpAdaptor):
+// CHECK: OPERATION_NAME = "test.attributed_op"
+// CHECK: @builtins.property
+// CHECK: def i32attr(self) -> _ods_ir.IntegerAttr:
+// CHECK: return self.attributes["i32attr"]
+// CHECK: @builtins.property
+// CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.FloatAttr]:
+// CHECK: if "optionalF32Attr" not in self.attributes:
+// CHECK: return None
+// CHECK: return self.attributes["optionalF32Attr"]
+// CHECK: @builtins.property
+// CHECK: def unitAttr(self) -> bool:
+// CHECK: return "unitAttr" in self.attributes
+// CHECK: @builtins.property
+// CHECK: def in_(self) -> _ods_ir.IntegerAttr:
+// CHECK: return self.attributes["in"]
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> AttributedOp:
// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)
@@ -383,18 +383,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results
-// CHECK-LABEL: class OneOptionalOperandOpAdaptor(_ods_OpAdaptor):
-// CHECK: OPERATION_NAME = "test.one_optional_operand"
-// CHECK: @builtins.property
-// CHECK: def non_optional(self) -> _ods_ir.Value:
-// CHECK: return self.operands[0]
-// CHECK: @builtins.property
-// CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
-// CHECK: return None if len(self.operands) < 2 else self.operands[1]
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class OneOptionalOperandOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.one_optional_operand"
-// CHECK: Adaptor = OneOptionalOperandOpAdaptor
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
@@ -420,23 +411,22 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}
+// CHECK: @_ods_cext.register_op_adaptor(OneOptionalOperandOp)
+// CHECK-LABEL: class OneOptionalOperandOpAdaptor(_ods_OpAdaptor):
+// CHECK: OPERATION_NAME = "test.one_optional_operand"
+// CHECK: @builtins.property
+// CHECK: def non_optional(self) -> _ods_ir.Value:
+// CHECK: return self.operands[0]
+// CHECK: @builtins.property
+// CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
+// CHECK: return None if len(self.operands) < 2 else self.operands[1]
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> OneOptionalOperandOp:
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)
-// CHECK-LABEL: class OneVariadicOperandOpAdaptor(_ods_OpAdaptor):
-// CHECK: OPERATION_NAME = "test.one_variadic_operand"
-// CHECK: @builtins.property
-// CHECK: def non_variadic(self) -> _ods_ir.Value:
-// CHECK: return self.operands[0]
-// CHECK: @builtins.property
-// CHECK: def variadic(self) -> _ods_ir.OpOperandList:
-// CHECK: _ods_variadic_group_length = len(self.operands) - 2 + 1
-// CHECK: return self.operands[1:1 + _ods_variadic_group_length]
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK-LABEL: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.one_variadic_operand"
-// CHECK: Adaptor = OneVariadicOperandOpAdaptor
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
@@ -463,6 +453,16 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: return self.operation.operands[1:1 + _ods_variadic_group_length]
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
}
+// CHECK: @_ods_cext.register_op_adaptor(OneVariadicOperandOp)
+// CHECK-LABEL: class OneVariadicOperandOpAdaptor(_ods_OpAdaptor):
+// CHECK: OPERATION_NAME = "test.one_variadic_operand"
+// CHECK: @builtins.property
+// CHECK: def non_variadic(self) -> _ods_ir.Value:
+// CHECK: return self.operands[0]
+// CHECK: @builtins.property
+// CHECK: def variadic(self) -> _ods_ir.OpOperandList:
+// CHECK: _ods_variadic_group_length = len(self.operands) - 2 + 1
+// CHECK: return self.operands[1:1 + _ods_variadic_group_length]
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> OneVariadicOperandOp:
// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 74f6d7edea4c2..6391842567617 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -70,13 +70,13 @@ constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):{2}
OPERATION_NAME = "{1}"
- Adaptor = {0}Adaptor
)Py";
/// Template for operation class:
/// {0} is the Python class name;
/// {1} is the operation name。
constexpr const char *opAdaptorClassTemplate = R"Py(
+ at _ods_cext.register_op_adaptor({0})
class {0}Adaptor(_ods_OpAdaptor):
OPERATION_NAME = "{1}"
)Py";
@@ -1277,11 +1277,6 @@ static void emitAdaptorOperandAccessors(const Operator &op, raw_ostream &os) {
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
- os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
- op.getOperationName());
- emitAdaptorOperandAccessors(op, os);
- emitAdaptorAttributeAccessors(op, os);
-
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
makeDocStringForOp(op));
@@ -1299,6 +1294,12 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
emitAttributeAccessors(op, os);
emitResultAccessors(op, os);
emitRegionAccessors(op, os);
+
+ os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
+ op.getOperationName());
+ emitAdaptorOperandAccessors(op, os);
+ emitAdaptorAttributeAccessors(op, os);
+
emitValueBuilder(op, functionArgs, os);
}
>From 0a04cb592b1bda89de74f07fef59bb29a722143b Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 25 Jan 2026 00:01:29 +0800
Subject: [PATCH 06/13] add c apis
---
mlir/include/mlir-c/Rewrite.h | 116 +++++++++++++
mlir/include/mlir/CAPI/Rewrite.h | 6 +
.../mlir/Transforms/DialectConversion.h | 2 +
mlir/lib/CAPI/Transforms/Rewrite.cpp | 158 ++++++++++++++++++
4 files changed, 282 insertions(+)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 26f7f08535b41..f32a8d880a52f 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -59,6 +59,11 @@ typedef enum {
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
+/// Opaque handles for the dialect conversion C API.
+DEFINE_C_API_STRUCT(MlirConversionTarget, void);
+DEFINE_C_API_STRUCT(MlirConversionPattern, const void);
+DEFINE_C_API_STRUCT(MlirTypeConverter, void);
+DEFINE_C_API_STRUCT(MlirConversionPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirConversionConfig, void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -423,6 +428,16 @@ MLIR_CAPI_EXPORTED void
mlirWalkAndApplyPatterns(MlirOperation op,
MlirFrozenRewritePatternSet patterns);
+/// Apply a partial conversion on the given operation.
+/// Forwards to mlir::applyPartialConversion and returns its result.
+/// `target`, `patterns` and `config` are dereferenced and must be non-null.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPartialConversion(
+ MlirOperation op, MlirConversionTarget target,
+ MlirFrozenRewritePatternSet patterns, MlirConversionConfig config);
+
+/// Apply a full conversion on the given operation.
+/// Forwards to mlir::applyFullConversion and returns its result.
+/// `target`, `patterns` and `config` are dereferenced and must be non-null.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(
+ MlirOperation op, MlirConversionTarget target,
+ MlirFrozenRewritePatternSet patterns, MlirConversionConfig config);
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
@@ -431,6 +446,107 @@ mlirWalkAndApplyPatterns(MlirOperation op,
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
+//===----------------------------------------------------------------------===//
+/// ConversionPatternRewriter API
+//===----------------------------------------------------------------------===//
+
+/// Cast the ConversionPatternRewriter to a PatternRewriter. The result refers
+/// to the same rewriter object; nothing is copied.
+MLIR_CAPI_EXPORTED MlirPatternRewriter
+mlirConversionPatternRewriterAsPatternRewriter(
+ MlirConversionPatternRewriter rewriter);
+
+//===----------------------------------------------------------------------===//
+/// ConversionTarget API
+//===----------------------------------------------------------------------===//
+
+/// Create an empty ConversionTarget for the given context. The target is
+/// heap-allocated and must be released with mlirConversionTargetDestroy.
+MLIR_CAPI_EXPORTED MlirConversionTarget
+mlirConversionTargetCreate(MlirContext context);
+
+/// Destroy the given ConversionTarget.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetDestroy(MlirConversionTarget target);
+
+/// Register the operation with the given full name (e.g. "dialect.op") as
+/// legal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddLegalOp(MlirConversionTarget target,
+ MlirStringRef opName);
+
+/// Register the operation with the given full name (e.g. "dialect.op") as
+/// illegal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
+ MlirStringRef opName);
+
+/// Register all operations of the dialect with the given namespace as legal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
+ MlirStringRef dialectName);
+
+/// Register all operations of the dialect with the given namespace as
+/// illegal.
+MLIR_CAPI_EXPORTED void
+mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
+ MlirStringRef dialectName);
+
+//===----------------------------------------------------------------------===//
+/// TypeConverter API
+//===----------------------------------------------------------------------===//
+
+/// Create a TypeConverter.
+MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate();
+
+/// Destroy the given TypeConverter.
+MLIR_CAPI_EXPORTED void
+mlirTypeConverterDestroy(MlirTypeConverter typeConverter);
+
+/// Callback type for type conversion functions. `*convertedType` is null on
+/// entry. The callback reports one of three outcomes:
+/// - success with a non-null `*convertedType`: `type` converts to it;
+/// - success with `*convertedType` left null: `type` cannot be converted and
+///   the conversion fails;
+/// - failure: this function declines, and the converter is allowed to try
+///   another conversion function to perform the conversion.
+typedef MlirLogicalResult (*MlirTypeConverterConversionCallback)(
+ MlirType type, MlirType *convertedType, void *userData);
+
+/// Add a type conversion function to the given TypeConverter. `userData` is
+/// passed unchanged to every invocation of `convertType`; the converter keeps
+/// the raw pointer and does not take ownership of it.
+MLIR_CAPI_EXPORTED void
+mlirTypeConverterAddConversion(MlirTypeConverter typeConverter,
+ MlirTypeConverterConversionCallback convertType,
+ void *userData);
+
+//===----------------------------------------------------------------------===//
+/// ConversionPattern API
+//===----------------------------------------------------------------------===//
+
+/// Callbacks implementing a conversion pattern created with
+/// mlirOpConversionPatternCreate. Every callback receives the `userData`
+/// pointer that was supplied when the pattern was created.
+typedef struct {
+ /// Optional constructor for the user data, called once when the pattern is
+ /// created. Set to nullptr to disable it.
+ void (*construct)(void *userData);
+ /// Optional destructor for the user data, called once when the pattern is
+ /// destroyed. Set to nullptr to disable it.
+ void (*destruct)(void *userData);
+ /// The callback function to match against code rooted at the specified
+ /// operation, and perform the conversion rewrite if the match is successful,
+ /// corresponding to ConversionPattern::matchAndRewrite. `operands` points to
+ /// `nOperands` values supplied by the conversion driver; the array is only
+ /// valid for the duration of the call and must not be retained.
+ MlirLogicalResult (*matchAndRewrite)(MlirConversionPattern pattern,
+ MlirOperation op, intptr_t nOperands,
+ MlirValue *operands,
+ MlirConversionPatternRewriter rewriter,
+ void *userData);
+} MlirConversionPatternCallbacks;
+
+/// Create a conversion pattern that matches the operation with the given
+/// rootName, corresponding to mlir::OpConversionPattern.
+/// The pattern is heap-allocated. `typeConverter` must be non-null and must
+/// outlive the pattern: the pattern keeps a reference to it, not a copy.
+/// `generatedNames` points to `nGeneratedNames` operation names and may be
+/// null when the count is zero.
+/// NOTE(review): ownership of the returned pattern presumably passes to the
+/// pattern set it is added to — confirm against the pattern-set C API.
+MLIR_CAPI_EXPORTED MlirConversionPattern mlirOpConversionPatternCreate(
+ MlirStringRef rootName, unsigned benefit, MlirContext context,
+ MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
+ void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+/// Get the type converter used by this conversion pattern, i.e. the one passed
+/// at creation (no copy is made).
+MLIR_CAPI_EXPORTED MlirTypeConverter
+mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern);
+
+/// Cast the ConversionPattern to a RewritePattern referring to the same object.
+MLIR_CAPI_EXPORTED MlirRewritePattern
+mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern);
+
//===----------------------------------------------------------------------===//
/// RewritePattern API
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 9c96d354d4fc9..41697eca90cc0 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -19,6 +19,7 @@
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern)
@@ -26,6 +27,11 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet,
mlir::FrozenRewritePatternSet)
DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter)
+// wrap/unwrap between the dialect conversion C handles and their C++ classes.
+DEFINE_C_API_PTR_METHODS(MlirConversionTarget, mlir::ConversionTarget)
+DEFINE_C_API_PTR_METHODS(MlirConversionPattern, const mlir::ConversionPattern)
+DEFINE_C_API_PTR_METHODS(MlirTypeConverter, mlir::TypeConverter)
+DEFINE_C_API_PTR_METHODS(MlirConversionPatternRewriter,
+ mlir::ConversionPatternRewriter)
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9f449080b0f37..0f67b9eceab59 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1108,6 +1108,8 @@ class ConversionTarget {
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
+ /// Returns the MLIRContext this target was constructed with.
+ MLIRContext &getContext() const { return ctx; }
+
//===--------------------------------------------------------------------===//
// Legality Registration
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 798ca1de651c1..af499dc4d80f3 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -8,6 +8,7 @@
#include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Rewrite.h"
@@ -17,6 +18,7 @@
 #include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
@@ -439,6 +441,22 @@ void mlirWalkAndApplyPatterns(MlirOperation op,
mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
}
+// C wrappers around mlir::applyPartialConversion / mlir::applyFullConversion.
+// NOTE(review): in this revision `config` is accepted but not forwarded, so
+// (presumably) the C++ driver's default ConversionConfig is used.
+MlirLogicalResult
+mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target,
+ MlirFrozenRewritePatternSet patterns,
+ MlirConversionConfig config) {
+ return wrap(mlir::applyPartialConversion(unwrap(op), *unwrap(target),
+ *unwrap(patterns)));
+}
+
+MlirLogicalResult mlirApplyFullConversion(MlirOperation op,
+ MlirConversionTarget target,
+ MlirFrozenRewritePatternSet patterns,
+ MlirConversionConfig config) {
+ return wrap(mlir::applyFullConversion(unwrap(op), *unwrap(target),
+ *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
@@ -447,6 +465,146 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
+//===----------------------------------------------------------------------===//
+/// ConversionPatternRewriter API
+//===----------------------------------------------------------------------===//
+
+MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
+    MlirConversionPatternRewriter rewriter) {
+  // ConversionPatternRewriter derives from PatternRewriter, so this is a plain
+  // upcast of the same object.
+  mlir::PatternRewriter *base = unwrap(rewriter);
+  return wrap(base);
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionTarget API
+//===----------------------------------------------------------------------===//
+
+MlirConversionTarget mlirConversionTargetCreate(MlirContext context) {
+  mlir::MLIRContext *ctx = unwrap(context);
+  return wrap(new mlir::ConversionTarget(*ctx));
+}
+
+void mlirConversionTargetDestroy(MlirConversionTarget target) {
+  mlir::ConversionTarget *owned = unwrap(target);
+  delete owned;
+}
+
+/// Builds the OperationName for `opName` in the context `target` was created
+/// with.
+static mlir::OperationName
+conversionTargetOpName(mlir::ConversionTarget &target, MlirStringRef opName) {
+  return mlir::OperationName(unwrap(opName), &target.getContext());
+}
+
+void mlirConversionTargetAddLegalOp(MlirConversionTarget target,
+                                    MlirStringRef opName) {
+  mlir::ConversionTarget &convTarget = *unwrap(target);
+  convTarget.addLegalOp(conversionTargetOpName(convTarget, opName));
+}
+
+void mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
+                                      MlirStringRef opName) {
+  mlir::ConversionTarget &convTarget = *unwrap(target);
+  convTarget.addIllegalOp(conversionTargetOpName(convTarget, opName));
+}
+
+void mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
+                                         MlirStringRef dialectName) {
+  mlir::ConversionTarget &convTarget = *unwrap(target);
+  convTarget.addLegalDialect(unwrap(dialectName));
+}
+
+void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
+                                           MlirStringRef dialectName) {
+  mlir::ConversionTarget &convTarget = *unwrap(target);
+  convTarget.addIllegalDialect(unwrap(dialectName));
+}
+
+//===----------------------------------------------------------------------===//
+/// TypeConverter API
+//===----------------------------------------------------------------------===//
+
+MlirTypeConverter mlirTypeConverterCreate() {
+  auto *converter = new mlir::TypeConverter();
+  return wrap(converter);
+}
+
+void mlirTypeConverterDestroy(MlirTypeConverter typeConverter) {
+  mlir::TypeConverter *converter = unwrap(typeConverter);
+  delete converter;
+}
+
+void mlirTypeConverterAddConversion(
+    MlirTypeConverter typeConverter,
+    MlirTypeConverterConversionCallback convertType, void *userData) {
+  // Bridge the C callback onto TypeConverter's std::optional<Type> protocol:
+  //   callback failure         -> std::nullopt (another conversion may run),
+  //   success with a null type -> null Type (the conversion fails),
+  //   success with a type      -> that type.
+  mlir::TypeConverter *converter = unwrap(typeConverter);
+  converter->addConversion(
+      [convertType, userData](Type type) -> std::optional<Type> {
+        MlirType out{nullptr};
+        if (mlirLogicalResultIsFailure(convertType(wrap(type), &out, userData)))
+          return std::nullopt;
+        if (mlirTypeIsNull(out))
+          return Type();
+        return unwrap(out);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionPattern API
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+
+/// A ConversionPattern whose matchAndRewrite is implemented by C callbacks.
+/// The optional `construct`/`destruct` callbacks are invoked exactly once, from
+/// the constructor and the destructor respectively, with `userData`.
+class ExternalConversionPattern : public mlir::ConversionPattern {
+public:
+  /// `typeConverter` is dereferenced and must be non-null; the base class
+  /// keeps a reference to it, so it must outlive this pattern.
+  ExternalConversionPattern(MlirConversionPatternCallbacks callbacks,
+                            void *userData, StringRef rootName,
+                            PatternBenefit benefit, MLIRContext *context,
+                            TypeConverter *typeConverter,
+                            ArrayRef<StringRef> generatedNames)
+      : ConversionPattern(*typeConverter, rootName, benefit, context,
+                          generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  // The destructor hands `userData` to `destruct`; a copy would do so twice.
+  // Declaring these also suppresses the implicit move operations.
+  ExternalConversionPattern(const ExternalConversionPattern &) = delete;
+  ExternalConversionPattern &
+  operator=(const ExternalConversionPattern &) = delete;
+
+  ~ExternalConversionPattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The wrapped operand array only lives for the duration of the callback.
+    std::vector<MlirValue> wrappedOperands;
+    wrappedOperands.reserve(operands.size());
+    for (Value val : operands)
+      wrappedOperands.push_back(wrap(val));
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::ConversionPattern *>(this)), wrap(op),
+        wrappedOperands.size(), wrappedOperands.data(), wrap(&rewriter),
+        userData));
+  }
+
+private:
+  MlirConversionPatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirConversionPattern mlirOpConversionPatternCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
+    void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  // Translate the C array of names into StringRefs for the C++ constructor.
+  std::vector<mlir::StringRef> names;
+  names.reserve(nGeneratedNames);
+  for (MlirStringRef name :
+       llvm::ArrayRef<MlirStringRef>(generatedNames, nGeneratedNames))
+    names.push_back(unwrap(name));
+
+  auto *pattern = new mlir::ExternalConversionPattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), unwrap(typeConverter), names);
+  return wrap(pattern);
+}
+
+MlirTypeConverter
+mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern) {
+  // The C handle type is non-const, so drop the constness of the stored
+  // converter pointer.
+  const TypeConverter *converter = unwrap(pattern)->getTypeConverter();
+  return wrap(const_cast<TypeConverter *>(converter));
+}
+
+MlirRewritePattern
+mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern) {
+  // Plain upcast: a ConversionPattern is a RewritePattern.
+  const RewritePattern *base = unwrap(pattern);
+  return wrap(base);
+}
+
//===----------------------------------------------------------------------===//
/// RewritePattern API
//===----------------------------------------------------------------------===//
>From b5a9f4b2a059b0ce72c8e354e08582751e17523c Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 25 Jan 2026 00:19:38 +0800
Subject: [PATCH 07/13] add more c apis
---
mlir/include/mlir-c/Rewrite.h | 35 ++++++++++++++++++
mlir/include/mlir/CAPI/Rewrite.h | 3 +-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 55 +++++++++++++++++++++++++++-
3 files changed, 90 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index f32a8d880a52f..837219a236b64 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -438,6 +438,41 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyFullConversion(
MlirOperation op, MlirConversionTarget target,
MlirFrozenRewritePatternSet patterns, MlirConversionConfig config);
+//===----------------------------------------------------------------------===//
+/// ConversionConfig API
+//===----------------------------------------------------------------------===//
+
+/// Create a default ConversionConfig. The config is heap-allocated and must be
+/// released with mlirConversionConfigDestroy.
+MLIR_CAPI_EXPORTED MlirConversionConfig mlirConversionConfigCreate(void);
+
+/// Destroy the given ConversionConfig.
+MLIR_CAPI_EXPORTED void
+mlirConversionConfigDestroy(MlirConversionConfig config);
+
+/// Folding modes for dialect conversion. Each value maps one-to-one onto
+/// mlir::DialectConversionFoldingMode (Never, BeforePatterns, AfterPatterns).
+typedef enum {
+ MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER,
+ MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS,
+ MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS,
+} MlirDialectConversionFoldingMode;
+
+/// Set the folding mode for the given ConversionConfig. `mode` must be one of
+/// the MlirDialectConversionFoldingMode enumerators.
+MLIR_CAPI_EXPORTED void
+mlirConversionConfigSetFoldingMode(MlirConversionConfig config,
+ MlirDialectConversionFoldingMode mode);
+
+/// Get the folding mode for the given ConversionConfig.
+MLIR_CAPI_EXPORTED MlirDialectConversionFoldingMode
+mlirConversionConfigGetFoldingMode(MlirConversionConfig config);
+
+/// Enable or disable building materializations during conversion
+/// (ConversionConfig::buildMaterializations).
+MLIR_CAPI_EXPORTED void
+mlirConversionConfigEnableBuildMaterializations(MlirConversionConfig config,
+ bool enable);
+
+/// Check if building materializations during conversion is enabled.
+MLIR_CAPI_EXPORTED bool
+mlirConversionConfigIsBuildMaterializationsEnabled(MlirConversionConfig config);
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 41697eca90cc0..a172f59b3e3ce 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -32,6 +32,7 @@ DEFINE_C_API_PTR_METHODS(MlirConversionPattern, const mlir::ConversionPattern)
DEFINE_C_API_PTR_METHODS(MlirTypeConverter, mlir::TypeConverter)
DEFINE_C_API_PTR_METHODS(MlirConversionPatternRewriter,
mlir::ConversionPatternRewriter)
+DEFINE_C_API_PTR_METHODS(MlirConversionConfig, mlir::ConversionConfig)
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
@@ -39,4 +40,4 @@ DEFINE_C_API_PTR_METHODS(MlirPDLResultList, mlir::PDLResultList)
DEFINE_C_API_PTR_METHODS(MlirPDLValue, const mlir::PDLValue)
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
-#endif // MLIR_CAPIREWRITER_H
+#endif // MLIR_CAPI_REWRITE_H
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index af499dc4d80f3..24893aac7c50c 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -446,7 +446,7 @@ mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target,
MlirFrozenRewritePatternSet patterns,
MlirConversionConfig config) {
return wrap(mlir::applyPartialConversion(unwrap(op), *unwrap(target),
- *unwrap(patterns)));
+ *unwrap(patterns), *unwrap(config)));
}
MlirLogicalResult mlirApplyFullConversion(MlirOperation op,
@@ -454,7 +454,58 @@ MlirLogicalResult mlirApplyFullConversion(MlirOperation op,
MlirFrozenRewritePatternSet patterns,
MlirConversionConfig config) {
return wrap(mlir::applyFullConversion(unwrap(op), *unwrap(target),
- *unwrap(patterns)));
+ *unwrap(patterns), *unwrap(config)));
+}
+
+//===----------------------------------------------------------------------===//
+/// ConversionConfig API
+//===----------------------------------------------------------------------===//
+
+MlirConversionConfig mlirConversionConfigCreate(void) {
+  return wrap(new mlir::ConversionConfig());
+}
+
+void mlirConversionConfigDestroy(MlirConversionConfig config) {
+  delete unwrap(config);
+}
+
+/// Maps the C folding-mode enum onto mlir::DialectConversionFoldingMode.
+/// A C caller can pass any integer, so out-of-range values are treated as a
+/// programming error instead of leaving the result uninitialized.
+static mlir::DialectConversionFoldingMode
+unwrapFoldingMode(MlirDialectConversionFoldingMode mode) {
+  switch (mode) {
+  case MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER:
+    return mlir::DialectConversionFoldingMode::Never;
+  case MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS:
+    return mlir::DialectConversionFoldingMode::BeforePatterns;
+  case MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS:
+    return mlir::DialectConversionFoldingMode::AfterPatterns;
+  }
+  llvm_unreachable("unknown MlirDialectConversionFoldingMode");
+}
+
+void mlirConversionConfigSetFoldingMode(MlirConversionConfig config,
+                                        MlirDialectConversionFoldingMode mode) {
+  unwrap(config)->foldingMode = unwrapFoldingMode(mode);
+}
+
+MlirDialectConversionFoldingMode
+mlirConversionConfigGetFoldingMode(MlirConversionConfig config) {
+  switch (unwrap(config)->foldingMode) {
+  case mlir::DialectConversionFoldingMode::Never:
+    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER;
+  case mlir::DialectConversionFoldingMode::BeforePatterns:
+    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS;
+  case mlir::DialectConversionFoldingMode::AfterPatterns:
+    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS;
+  }
+  // Covered switch; this keeps -Wreturn-type quiet and traps on corruption.
+  llvm_unreachable("unknown mlir::DialectConversionFoldingMode");
+}
+
+void mlirConversionConfigEnableBuildMaterializations(
+    MlirConversionConfig config, bool enable) {
+  unwrap(config)->buildMaterializations = enable;
+}
+
+bool mlirConversionConfigIsBuildMaterializationsEnabled(
+    MlirConversionConfig config) {
+  return unwrap(config)->buildMaterializations;
 }
//===----------------------------------------------------------------------===//
>From b1a340b353c8f6264ef4f05908a78a94de6fe025 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 25 Jan 2026 09:48:18 +0800
Subject: [PATCH 08/13] fix
---
mlir/include/mlir-c/Rewrite.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 837219a236b64..b4f93fd5a9b78 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -527,7 +527,7 @@ mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
//===----------------------------------------------------------------------===//
/// Create a TypeConverter.
-MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate();
+MLIR_CAPI_EXPORTED MlirTypeConverter mlirTypeConverterCreate(void);
/// Destroy the given TypeConverter.
MLIR_CAPI_EXPORTED void
>From cc12c63255af9053a27261e9549d6b0877e17012 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 25 Jan 2026 13:46:03 +0800
Subject: [PATCH 09/13] add python apis
---
mlir/lib/Bindings/Python/Rewrite.cpp | 306 ++++++++++++++++++++++-
mlir/python/mlir/dialects/_ods_common.py | 23 +-
mlir/test/python/rewrite.py | 58 +++++
3 files changed, 377 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 2b649f79c5982..1b599f425215e 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -11,6 +11,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
@@ -18,6 +19,7 @@
// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
+#include <type_traits>
namespace nb = nanobind;
using namespace mlir;
@@ -62,9 +64,94 @@ class PyPatternRewriter {
PyMlirContextRef ctx;
};
-struct PyMlirPDLResultList : MlirPDLResultList {};
+/// Python-side wrapper for a ConversionPatternRewriter. It is exposed to Python
+/// as a subclass of PatternRewriter, so the base must be inherited publicly for
+/// the PatternRewriter API to be reachable through this type.
+class PyConversionPatternRewriter : public PyPatternRewriter {
+public:
+  PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
+      : PyPatternRewriter(
+            mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {}
+};
+
+/// Owning wrapper around an MlirConversionTarget. The target is destroyed in
+/// the destructor, so the wrapper is non-copyable to rule out a double destroy.
+class PyConversionTarget {
+public:
+  PyConversionTarget(MlirContext context)
+      : target(mlirConversionTargetCreate(context)) {}
+  ~PyConversionTarget() { mlirConversionTargetDestroy(target); }
+
+  PyConversionTarget(const PyConversionTarget &) = delete;
+  PyConversionTarget &operator=(const PyConversionTarget &) = delete;
+
+  /// Marks the operation named `opName` as legal.
+  void addLegalOp(const std::string &opName) {
+    mlirConversionTargetAddLegalOp(target, makeStringRef(opName));
+  }
+
+  /// Marks the operation named `opName` as illegal.
+  void addIllegalOp(const std::string &opName) {
+    mlirConversionTargetAddIllegalOp(target, makeStringRef(opName));
+  }
+
+  /// Marks every operation of the dialect `dialectName` as legal.
+  void addLegalDialect(const std::string &dialectName) {
+    mlirConversionTargetAddLegalDialect(target, makeStringRef(dialectName));
+  }
+
+  /// Marks every operation of the dialect `dialectName` as illegal.
+  void addIllegalDialect(const std::string &dialectName) {
+    mlirConversionTargetAddIllegalDialect(target, makeStringRef(dialectName));
+  }
+
+  MlirConversionTarget get() { return target; }
+
+private:
+  /// Borrows `s` as an MlirStringRef; `s` must outlive the returned value.
+  static MlirStringRef makeStringRef(const std::string &s) {
+    return mlirStringRefCreate(s.data(), s.size());
+  }
+
+  MlirConversionTarget target;
+};
+
+/// Wrapper around an MlirTypeConverter that either owns it (default-constructed
+/// from Python) or merely borrows it (constructed from an existing handle).
+/// Only an owning wrapper destroys the converter. Copying is disabled and a
+/// move hands ownership over, so an owned converter is destroyed exactly once.
+class PyTypeConverter {
+public:
+  PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {}
+  PyTypeConverter(MlirTypeConverter typeConverter)
+      : typeConverter(typeConverter), owner(false) {}
+
+  PyTypeConverter(const PyTypeConverter &) = delete;
+  PyTypeConverter &operator=(const PyTypeConverter &) = delete;
+  PyTypeConverter(PyTypeConverter &&other) noexcept
+      : typeConverter(other.typeConverter), owner(other.owner) {
+    // The moved-from wrapper must not destroy the converter it handed over.
+    other.owner = false;
+  }
+  PyTypeConverter &operator=(PyTypeConverter &&) = delete;
+
+  ~PyTypeConverter() {
+    if (owner)
+      mlirTypeConverterDestroy(typeConverter);
+  }
+
+  /// Registers the Python callable `convert` as a type conversion. It is
+  /// called with the (downcast) type and returns either a converted type or
+  /// None. The callable is passed as a borrowed pointer (no reference is taken
+  /// here), so the caller must keep it alive while the converter is in use.
+  void addConversion(const nb::callable &convert) {
+    mlirTypeConverterAddConversion(
+        typeConverter,
+        [](MlirType type, MlirType *converted,
+           void *userData) -> MlirLogicalResult {
+          nb::handle f = nb::handle(static_cast<PyObject *>(userData));
+          auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type));
+          nb::object res = f(PyType(ctx, type).maybeDownCast());
+          // None means this callback does not convert `type`.
+          if (res.is_none())
+            return mlirLogicalResultFailure();
+
+          *converted = nb::cast<PyType>(res).get();
+          return mlirLogicalResultSuccess();
+        },
+        convert.ptr());
+  }
+
+  MlirTypeConverter get() { return typeConverter; }
+
+private:
+  MlirTypeConverter typeConverter;
+  bool owner;
+};
+
+/// Non-owning handle to an MlirConversionPattern, used to hand the pattern's
+/// type converter to Python match-and-rewrite callbacks.
+class PyConversionPattern {
+public:
+  PyConversionPattern(MlirConversionPattern wrapped) : pattern(wrapped) {}
+
+  /// Returns a borrowing (non-owning) wrapper around the pattern's converter.
+  PyTypeConverter getTypeConverter() {
+    MlirTypeConverter converter =
+        mlirConversionPatternGetTypeConverter(pattern);
+    return PyTypeConverter(converter);
+  }
+
+private:
+  MlirConversionPattern pattern;
+};
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+struct PyMlirPDLResultList : MlirPDLResultList {};
+
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
return nb::cast(v);
@@ -216,6 +303,46 @@ class PyRewritePatternSet {
mlirRewritePatternSetAdd(set, pattern);
}
+  /// Adds a conversion pattern rooted at `rootName` whose match-and-rewrite
+  /// step is the Python callable `matchAndRewrite`. The callable receives the
+  /// op as an OpView, an adaptor built from the callback's operands and the
+  /// op, the pattern's type converter, and a ConversionPatternRewriter.
+  void addConversion(MlirStringRef rootName, unsigned benefit,
+                     const nb::callable &matchAndRewrite,
+                     PyTypeConverter &typeConverter) {
+    // Value-initialize so any callback slot not assigned below is null rather
+    // than indeterminate.
+    MlirConversionPatternCallbacks callbacks{};
+    // construct/destruct take and release a reference on the Python callable
+    // that is passed to the pattern as userData.
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite =
+        [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
+           MlirValue *operands, MlirConversionPatternRewriter rewriter,
+           void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+
+      PyMlirContextRef ctx =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+      std::vector<MlirValue> operandsVec(operands, operands + nOperands);
+      std::optional<nb::object> adaptorCls =
+          PyGlobals::get().lookupOpAdaptorClass(
+              unwrap(mlirIdentifierStr(mlirOperationGetName(op))));
+      assert(adaptorCls && "cannot find adaptor for this operation");
+
+      nb::object res = f(opView, adaptorCls.value()(operandsVec, opView),
+                         PyConversionPattern(pattern).getTypeConverter(),
+                         PyConversionPatternRewriter(rewriter));
+      return logicalResultFromObject(res);
+    };
+    MlirConversionPattern pattern = mlirOpConversionPatternCreate(
+        rootName, benefit, ctx, typeConverter.get(), callbacks,
+        matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set,
+                             mlirConversionPatternAsRewritePattern(pattern));
+  }
+
PyFrozenRewritePatternSet freeze() {
MlirRewritePatternSet s = set;
set.ptr = nullptr;
@@ -324,6 +451,46 @@ class PyGreedyRewriteConfig {
}
};
+/// Scoped C++ counterpart of the C enum MlirDialectConversionFoldingMode,
+/// exported to Python as `DialectConversionFoldingMode`. Each enumerator holds
+/// the same value as the corresponding C constant.
+enum class PyDialectConversionFoldingMode
+    : std::underlying_type<MlirDialectConversionFoldingMode>::type {
+  Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER,
+  BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS,
+  AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS,
+};
+
+/// Python-facing handle to an MlirConversionConfig. The config lives behind a
+/// shared_ptr with a custom deleter, so copies of this wrapper share (and
+/// jointly own) a single underlying config.
+class PyConversionConfig {
+public:
+  PyConversionConfig()
+      : config(mlirConversionConfigCreate().ptr,
+               &PyConversionConfig::destroy) {}
+
+  MlirConversionConfig get() {
+    MlirConversionConfig raw{config.get()};
+    return raw;
+  }
+
+  void setFoldingMode(PyDialectConversionFoldingMode mode) {
+    auto cMode = static_cast<MlirDialectConversionFoldingMode>(mode);
+    mlirConversionConfigSetFoldingMode(get(), cMode);
+  }
+
+  PyDialectConversionFoldingMode getFoldingMode() {
+    MlirDialectConversionFoldingMode cMode =
+        mlirConversionConfigGetFoldingMode(get());
+    return static_cast<PyDialectConversionFoldingMode>(cMode);
+  }
+
+  void enableBuildMaterializations(bool enabled) {
+    mlirConversionConfigEnableBuildMaterializations(get(), enabled);
+  }
+
+  bool isBuildMaterializationsEnabled() {
+    return mlirConversionConfigIsBuildMaterializationsEnabled(get());
+  }
+
+private:
+  /// Deleter invoked by the shared_ptr once the last copy goes away.
+  static void destroy(void *raw) {
+    mlirConversionConfigDestroy(MlirConversionConfig{raw});
+  }
+
+  std::shared_ptr<void> config;
+};
+
/// Create the `mlir.rewrite` here.
void populateRewriteSubmodule(nb::module_ &m) {
// Enum definitions
@@ -337,6 +504,12 @@ void populateRewriteSubmodule(nb::module_ &m) {
.value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
.value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL)
.value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
+
+ nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode")
+ .value("NEVER", PyDialectConversionFoldingMode::Never)
+ .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
+ .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);
+
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
@@ -409,9 +582,97 @@ void populateRewriteSubmodule(nb::module_ &m) {
If possible, the operation is cast to its corresponding OpView subclass
before being passed to the callable.
benefit: The benefit of the pattern, defaulting to 1.)")
+      .def(
+          "add_conversion",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             PyTypeConverter &typeConverter, unsigned benefit) {
+            // Accept an OpView subclass (via its OPERATION_NAME) or a plain
+            // operation-name string, as documented below.
+            std::string opName =
+                nb::isinstance<nb::str>(root)
+                    ? nb::cast<std::string>(root)
+                    : nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.addConversion(
+                mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
+                typeConverter);
+          },
+          "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
+          // The created pattern refers to the type converter's handle, so keep
+          // the Python TypeConverter (argument 4) alive as long as this set.
+          // NOTE(review): a FrozenRewritePatternSet from freeze() may outlive
+          // this set; confirm whether it also needs to pin the converter.
+          nb::keep_alive<1, 4>(),
+          R"(
+ Add a new conversion pattern on the specified root operation,
+ using the provided callable for matching and rewriting,
+ and assign it the given benefit.
+
+ Args:
+ root: The root operation to which this pattern applies.
+ This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
+ an operation name string (e.g., ``"arith.addi"``).
+ fn: The callable to use for matching and rewriting,
+ which takes an operation, its adaptor,
+ the type converter and a pattern rewriter as arguments.
+ The match is considered successful iff the callable returns
+ a value where ``bool(value)`` is ``False`` (e.g. ``None``).
+ If possible, the operation is cast to its corresponding OpView subclass
+ before being passed to the callable.
+ type_converter: The type converter to convert types in the IR.
+ benefit: The benefit of the pattern, defaulting to 1.)")
.def("freeze", &PyRewritePatternSet::freeze,
"Freeze the pattern set into a frozen one.");
+ nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
+ m, "ConversionPatternRewriter");
+
+  // Ops are identified by the OPERATION_NAME of an OpView subclass; dialects
+  // by the DIALECT_NAMESPACE of a dialect class.
+  nb::class_<PyConversionTarget>(m, "ConversionTarget")
+      .def(
+          "__init__",
+          [](PyConversionTarget &self, DefaultingPyMlirContext context) {
+            new (&self) PyConversionTarget(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add_legal_op",
+          [](PyConversionTarget &self, const nb::args &ops) {
+            for (auto op : ops) {
+              std::string opName =
+                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
+              self.addLegalOp(opName);
+            }
+          },
+          "ops"_a, "Mark the given operations as legal.")
+      .def(
+          "add_illegal_op",
+          [](PyConversionTarget &self, const nb::args &ops) {
+            for (auto op : ops) {
+              std::string opName =
+                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
+              self.addIllegalOp(opName);
+            }
+          },
+          "ops"_a, "Mark the given operations as illegal.")
+      .def(
+          "add_legal_dialect",
+          [](PyConversionTarget &self, const nb::args &dialects) {
+            for (auto dialect : dialects) {
+              std::string dialectName =
+                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
+              self.addLegalDialect(dialectName);
+            }
+          },
+          "dialects"_a, "Mark the given dialects as legal.")
+      .def(
+          "add_illegal_dialect",
+          [](PyConversionTarget &self, const nb::args &dialects) {
+            for (auto dialect : dialects) {
+              std::string dialectName =
+                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
+              self.addIllegalDialect(dialectName);
+            }
+          },
+          "dialects"_a, "Mark the given dialects as illegal.");
+
+  nb::class_<PyTypeConverter>(m, "TypeConverter")
+      .def(
+          "__init__",
+          [](PyTypeConverter &self) { new (&self) PyTypeConverter(); },
+          "Create a new TypeConverter.")
+      // PyTypeConverter::addConversion stores only a borrowed pointer to the
+      // callable, so tie its lifetime to the TypeConverter object
+      // (nurse = self, patient = convert).
+      .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
+           nb::keep_alive<1, 2>(), "Register a type conversion function.");
+
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
@@ -498,6 +759,17 @@ void populateRewriteSubmodule(nb::module_ &m) {
&PyGreedyRewriteConfig::enableConstantCSE,
"Enable or disable constant CSE");
+ nb::class_<PyConversionConfig>(m, "ConversionConfig")
+ .def(nb::init<>(), "Create a conversion config with defaults")
+ .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode,
+ &PyConversionConfig::setFoldingMode,
+ "folding behavior during dialect conversion")
+ .def_prop_rw("build_materializations",
+ &PyConversionConfig::isBuildMaterializationsEnabled,
+ &PyConversionConfig::enableBuildMaterializations,
+ "Whether the dialect conversion attempts to build "
+ "source/target materializations");
+
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyFrozenRewritePatternSet::getCapsule)
@@ -539,7 +811,37 @@ void populateRewriteSubmodule(nb::module_ &m) {
},
"op"_a, "set"_a,
"Applies the given patterns to the given op by a fast walk-based "
- "driver.");
+ "driver.")
+ .def(
+ "apply_partial_conversion",
+ [](PyOperationBase &op, PyConversionTarget &target,
+ PyFrozenRewritePatternSet &set,
+ std::optional<PyConversionConfig> config) {
+ if (!config) {
+ config.emplace(PyConversionConfig());
+ }
+ auto status = mlirApplyPartialConversion(
+ op.getOperation(), target.get(), set.get(), config->get());
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error("partial conversion failed");
+ },
+ "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
+ "Applies a partial conversion on the given operation.")
+ .def(
+ "apply_full_conversion",
+ [](PyOperationBase &op, PyConversionTarget &target,
+ PyFrozenRewritePatternSet &set,
+ std::optional<PyConversionConfig> config) {
+ if (!config) {
+ config.emplace(PyConversionConfig());
+ }
+ auto status = mlirApplyFullConversion(
+ op.getOperation(), target.get(), set.get(), config->get());
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error("full conversion failed");
+ },
+ "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
+ "Applies a full conversion on the given operation.");
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index e8b7aa81ef920..700ef1cb7a9b4 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -23,6 +23,8 @@
Operation,
ShapedType,
Value,
+ OpView,
+ OpAttributeMap,
)
__all__ = [
@@ -90,7 +92,7 @@ def get_default_loc_context(location=None):
def get_op_result_or_value(
arg: _Union[
_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
- ]
+ ],
) -> _cext.ir.Value:
"""Returns the given value or the single result of the given op.
@@ -114,7 +116,7 @@ def get_op_results_or_values(
_cext.ir.OpView,
_cext.ir.Operation,
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
- ]
+ ],
) -> _Union[
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
_cext.ir.OpResultList,
@@ -249,7 +251,7 @@ def _dispatch_mixed_values(
def _get_value_or_attribute_value(
- value_or_attr: _Union[any, Attribute, ArrayAttr]
+ value_or_attr: _Union[any, Attribute, ArrayAttr],
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
@@ -259,13 +261,13 @@ def _get_value_or_attribute_value(
def _get_value_list(
- sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
+ sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr],
) -> _Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(
- values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
+ values: _Optional[_Union[ArrayAttr, IntOrAttrList]],
) -> ArrayAttr:
if values is None:
return None
@@ -280,7 +282,7 @@ def _get_int_array_attr(
def _get_int_array_array_attr(
- values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
+ values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]],
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
@@ -308,6 +310,11 @@ def _get_int_array_array_attr(
class OpAdaptor:
-    def __init__(self, operands, attributes) -> None:
+    def __init__(
+        self,
+        operands: _Sequence[Value],
+        opview_or_attributes: _Union[OpView, OpAttributeMap],
+    ) -> None:
+        """Bundle operand values with an operation's attributes.
+
+        Args:
+            operands: The operand values to expose through the adaptor.
+            opview_or_attributes: Either an OpView, whose attributes are read
+                from ``.operation.attributes``, or an attribute map that is
+                stored as-is.
+        """
         self.operands = operands
-        self.attributes = attributes
+        if isinstance(opview_or_attributes, OpView):
+            self.attributes = opview_or_attributes.operation.attributes
+        else:
+            self.attributes = opview_or_attributes
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 8ef49981a8b3c..f164b21a04d55 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -226,3 +226,61 @@ def constant_1_to_2(op, rewriter):
# CHECK: %c2_i64 = arith.constant 2 : i64
# CHECK: return %c2_i64, %c2_i64 : i64
print(module)
+
+
+@run
+def testConversionPattern():
+    from mlir.dialects import smt
+
+    def convert_int(t):
+        # Only integer types are converted; falling through returns None,
+        # which makes this conversion callback report failure.
+        if isinstance(t, IntegerType):
+            return smt.IntType.get()
+
+    converter = TypeConverter()
+    converter.add_conversion(convert_int)
+
+    def convert_constant(op, adaptor, type_converter, rewriter):
+        with rewriter.ip:
+            new_op = smt.IntConstantOp(op.value, loc=op.location)
+        rewriter.replace_op(op, new_op)
+
+    def convert_addi(op, adaptor, type_converter, rewriter):
+        with rewriter.ip:
+            new_op = smt.IntAddOp([adaptor.lhs, adaptor.rhs], loc=op.location)
+        rewriter.replace_op(op, new_op)
+
+    def convert_muli(op, adaptor, type_converter, rewriter):
+        with rewriter.ip:
+            new_op = smt.IntMulOp([adaptor.lhs, adaptor.rhs], loc=op.location)
+        rewriter.replace_op(op, new_op)
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add_conversion(arith.ConstantOp, convert_constant, converter)
+        patterns.add_conversion(arith.AddIOp, convert_addi, converter)
+        patterns.add_conversion(arith.MulIOp, convert_muli, converter)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @f(%0: i64) -> i64 {
+                %1 = arith.constant 3 : i64
+                %2 = arith.addi %0, %1 : i64
+                %3 = arith.muli %2, %1 : i64
+                return %3 : i64
+              }
+            }
+            """
+        )
+
+        target = ConversionTarget()
+        target.add_legal_dialect(smt._Dialect)
+        target.add_illegal_op(arith.ConstantOp, arith.AddIOp, arith.MulIOp)
+
+        frozen = patterns.freeze()
+        config = ConversionConfig()
+        config.build_materializations = False
+
+        apply_partial_conversion(module, target, frozen, config)
+        assert module.operation.verify()
+        print(module)
>From 06bcdb96230ff61d5eba7ebf597b7dbc022c873f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 25 Jan 2026 13:52:35 +0800
Subject: [PATCH 10/13] add checks
---
mlir/test/python/rewrite.py | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index f164b21a04d55..aaed31b81aef6 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -283,4 +283,13 @@ def convert_muli(op, adaptor, type_converter, rewriter):
apply_partial_conversion(module, target, frozen, config)
assert module.operation.verify()
+
+ # CHECK: func.func @f(%arg0: i64) -> i64 {
+ # CHECK: %0 = builtin.unrealized_conversion_cast %arg0 : i64 to !smt.int
+ # CHECK: %c3 = smt.int.constant 3
+ # CHECK: %1 = smt.int.add %0, %c3
+ # CHECK: %2 = smt.int.mul %1, %c3
+ # CHECK: %3 = builtin.unrealized_conversion_cast %2 : !smt.int to i64
+ # CHECK: return %3 : i64
+ # CHECK: }
print(module)
>From 01ac4380aa4e2de6857d868de2031ee59499b9ae Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Wed, 28 Jan 2026 22:13:35 +0800
Subject: [PATCH 11/13] Update mlir/include/mlir/Bindings/Python/Globals.h
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/include/mlir/Bindings/Python/Globals.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index ad58c9374f766..6a722575c4e48 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -94,7 +94,7 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
void registerOperationImpl(const std::string &operationName,
nanobind::object pyClass, bool replace = false);
- /// Adds a operation adaptor class.
+ /// Adds an operation adaptor class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOpAdaptorImpl(const std::string &operationName,
>From b54a01ab1fb33d804852924e19885ba861c93255 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Wed, 28 Jan 2026 22:21:47 +0800
Subject: [PATCH 12/13] Update mlir/lib/Bindings/Python/Rewrite.cpp
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 1b599f425215e..137d97b20855b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -820,7 +820,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
if (!config) {
config.emplace(PyConversionConfig());
}
- auto status = mlirApplyPartialConversion(
+ MlirLogicalResult status = mlirApplyPartialConversion(
op.getOperation(), target.get(), set.get(), config->get());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("partial conversion failed");
>From 57fd1a31bd955fb410ece7d67a55407d6cf44630 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Wed, 28 Jan 2026 22:21:56 +0800
Subject: [PATCH 13/13] Update mlir/lib/Bindings/Python/Rewrite.cpp
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/lib/Bindings/Python/Rewrite.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 137d97b20855b..9a550a7df4547 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -817,9 +817,8 @@ void populateRewriteSubmodule(nb::module_ &m) {
[](PyOperationBase &op, PyConversionTarget &target,
PyFrozenRewritePatternSet &set,
std::optional<PyConversionConfig> config) {
- if (!config) {
+ if (!config)
config.emplace(PyConversionConfig());
- }
MlirLogicalResult status = mlirApplyPartialConversion(
op.getOperation(), target.get(), set.get(), config->get());
if (mlirLogicalResultIsFailure(status))
More information about the Mlir-commits
mailing list